From 18d8ea112c07154b7f5e377d7fdd7036eef91970 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Sun, 25 Oct 2015 18:36:52 +0800 Subject: [PATCH 01/25] ast, optimiser: change ast visitor API and implement cloner and binder --- ast/ast.go | 3 +- ast/cloner.go | 168 ++++++++++ ast/cloner_test.go | 40 +++ ast/ddl.go | 302 +++++++++-------- ast/dml.go | 619 ++++++++++++++++++++++------------- ast/expressions.go | 484 +++++++++++++++------------ ast/functions.go | 196 ++++++----- ast/misc.go | 208 +++++++----- ast/parser/parser.y | 178 +++++----- ast/parser/parser_test.go | 4 +- ast/parser/yy_parser.go | 19 ++ optimizer/infobinder.go | 441 +++++++++++++++++++++++++ optimizer/infobinder_test.go | 69 ++++ 13 files changed, 1882 insertions(+), 849 deletions(-) create mode 100644 ast/cloner.go create mode 100644 ast/cloner_test.go create mode 100644 ast/parser/yy_parser.go create mode 100644 optimizer/infobinder.go create mode 100644 optimizer/infobinder_test.go diff --git a/ast/ast.go b/ast/ast.go index 0e9ab19893..dabc4fa912 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -95,8 +95,7 @@ type Visitor interface { // VisitEnter is called before children nodes is visited. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. - // ok returns false to stop visiting. - Enter(n Node) (skipChildren bool, ok bool) + Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) diff --git a/ast/cloner.go b/ast/cloner.go new file mode 100644 index 0000000000..f08ef09476 --- /dev/null +++ b/ast/cloner.go @@ -0,0 +1,168 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +// Cloner is a ast visitor that clones a node. +type Cloner struct { +} + +// Enter implements Visitor Enter interface. +func (c *Cloner) Enter(node Node) (Node, bool) { + return cloneStruct(node), false +} + +// Leave implements Visitor Leave interface. +func (c *Cloner) Leave(in Node) (out Node, ok bool) { + return in, true +} + +// cloneStruct clone a node's struct value, if the struct has slice +// the cloned value should make a new slice and copy old slice to new slice. +func cloneStruct(in Node) (out Node) { + switch v := in.(type) { + case *ValueExpr: + nv := *v + out = &nv + case *BetweenExpr: + nv := *v + out = &nv + case *BinaryOperationExpr: + nv := *v + out = &nv + case *WhenClause: + nv := *v + out = &nv + case *CaseExpr: + nv := *v + nv.WhenClauses = make([]*WhenClause, len(v.WhenClauses)) + copy(nv.WhenClauses, v.WhenClauses) + out = &nv + case *SubqueryExpr: + nv := *v + out = &nv + case *CompareSubqueryExpr: + nv := *v + out = &nv + case *ColumnName: + nv := *v + out = &nv + case *DefaultExpr: + nv := *v + out = &nv + case *IdentifierExpr: + nv := *v + out = &nv + case *ExistsSubqueryExpr: + nv := *v + out = &nv + case *PatternInExpr: + nv := *v + nv.List = make([]ExprNode, len(v.List)) + copy(nv.List, v.List) + out = &nv + case *IsNullExpr: + nv := *v + out = &nv + case *IsTruthExpr: + nv := *v + out = &nv + case *PatternLikeExpr: + nv := *v + out = &nv + case *ParamMarkerExpr: + nv := *v + out = &nv + case *ParenthesesExpr: + nv := *v + out = &nv + case *PositionExpr: + nv := *v + out = &nv + case *PatternRegexpExpr: + nv := *v + out = &nv + case *RowExpr: + nv := *v + nv.Values = make([]ExprNode, len(v.Values)) + copy(nv.Values, v.Values) + out = &nv + case *UnaryOperationExpr: + nv := *v + out = &nv + case *ValuesExpr: + nv := *v + out = &nv + case *VariableExpr: + nv := *v + out = &nv + case *Join: + nv := *v + out = &nv + case *TableName: + nv := *v + out = &nv + case *TableSource: + nv := *v + out = &nv + case *OnCondition: + nv := *v + out = &nv + case *WildCardField: + nv := *v + out = &nv + case *SelectField: + nv := *v + out = &nv + case *FieldList: + nv := *v + nv.Fields = make([]*SelectField, len(v.Fields)) + copy(nv.Fields, v.Fields) + out = &nv + case *TableRefsClause: + nv := *v + out = &nv + case *ByItem: + nv := *v + out = &nv + case *GroupByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *HavingClause: + nv := *v + out = &nv + case *OrderByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *SelectStmt: + nv := *v + out = &nv + case *UnionClause: + nv := *v + out = &nv + case *UnionStmt: + nv := *v + nv.Selects = make([]*SelectStmt, len(v.Selects)) + copy(nv.Selects, v.Selects) + out = &nv + default: + // We currently only handle expression and select statement. + // Will add more when we need to. + panic("unknown ast Node type") + } + return +} diff --git a/ast/cloner_test.go b/ast/cloner_test.go new file mode 100644 index 0000000000..186ee824a7 --- /dev/null +++ b/ast/cloner_test.go @@ -0,0 +1,40 @@ +package ast + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/parser/opcode" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testClonerSuite{}) + +type testClonerSuite struct { +} + +func (ts *testClonerSuite) TestCloner(c *C) { + cloner := &Cloner{} + + a := &UnaryOperationExpr{ + Op: opcode.Not, + V: &UnaryOperationExpr{V: &ValueExpr{Val: true}}, + } + + b, ok := a.Accept(cloner) + c.Assert(ok, IsTrue) + a1 := a.V + b1 := b.(*UnaryOperationExpr).V + c.Assert(a1, Not(Equals), b1) + a2 := a1.(*UnaryOperationExpr).V + b2 := b1.(*UnaryOperationExpr).V + c.Assert(a2, Not(Equals), b2) + a3 := a2.(*ValueExpr) + b3 := b2.(*ValueExpr) + c.Assert(a3, Not(Equals), b3) + c.Assert(a3.Val, Equals, true) + c.Assert(b3.Val, Equals, true) +} diff --git a/ast/ddl.go b/ast/ddl.go index 0e43b69552..37a5d99c88 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -70,11 +70,13 @@ type CreateDatabaseStmt struct { } // Accept implements Node Accept interface. -func (cd *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (nod *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cd) + nod = newNod.(*CreateDatabaseStmt) + return v.Leave(nod) } // DropDatabaseStmt is a statement to drop a database and all tables in the database. @@ -87,11 +89,13 @@ type DropDatabaseStmt struct { } // Accept implements Node Accept interface. -func (dd *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dd); skipChildren { - return dd, ok +func (nod *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(dd) + nod = newNod.(*DropDatabaseStmt) + return v.Leave(nod) } // IndexColName is used for parsing index column name from SQL. @@ -103,16 +107,18 @@ type IndexColName struct { } // Accept implements Node Accept interface. -func (ic *IndexColName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ic); skipChildren { - return ic, ok +func (nod *IndexColName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ic.Column.Accept(v) + nod = newNod.(*IndexColName) + node, ok := nod.Column.Accept(v) if !ok { - return ic, false + return nod, false } - ic.Column = node.(*ColumnName) - return v.Leave(ic) + nod.Column = node.(*ColumnName) + return v.Leave(nod) } // ReferenceDef is used for parsing foreign key reference option from SQL. @@ -125,23 +131,25 @@ type ReferenceDef struct { } // Accept implements Node Accept interface. -func (rd *ReferenceDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rd); skipChildren { - return rd, ok +func (nod *ReferenceDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := rd.Table.Accept(v) + nod = newNod.(*ReferenceDef) + node, ok := nod.Table.Accept(v) if !ok { - return rd, false + return nod, false } - rd.Table = node.(*TableName) - for i, val := range rd.IndexColNames { + nod.Table = node.(*TableName) + for i, val := range nod.IndexColNames { node, ok = val.Accept(v) if !ok { - return rd, false + return nod, false } - rd.IndexColNames[i] = node.(*IndexColName) + nod.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(rd) + return v.Leave(nod) } // ColumnOptionType is the type for ColumnOption. @@ -175,18 +183,20 @@ type ColumnOption struct { } // Accept implements Node Accept interface. -func (co *ColumnOption) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(co); skipChildren { - return co, ok +func (nod *ColumnOption) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if co.Expr != nil { - node, ok := co.Expr.Accept(v) + nod = newNod.(*ColumnOption) + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) if !ok { - return co, false + return nod, false } - co.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) } - return v.Leave(co) + return v.Leave(nod) } // ConstraintType is the type for Constraint. @@ -220,25 +230,27 @@ type Constraint struct { } // Accept implements Node Accept interface. -func (tc *Constraint) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tc); skipChildren { - return tc, ok +func (nod *Constraint) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range tc.Keys { + nod = newNod.(*Constraint) + for i, val := range nod.Keys { node, ok := val.Accept(v) if !ok { - return tc, false + return nod, false } - tc.Keys[i] = node.(*IndexColName) + nod.Keys[i] = node.(*IndexColName) } - if tc.Refer != nil { - node, ok := tc.Refer.Accept(v) + if nod.Refer != nil { + node, ok := nod.Refer.Accept(v) if !ok { - return tc, false + return nod, false } - tc.Refer = node.(*ReferenceDef) + nod.Refer = node.(*ReferenceDef) } - return v.Leave(tc) + return v.Leave(nod) } // ColumnDef is used for parsing column definition from SQL. @@ -251,23 +263,25 @@ type ColumnDef struct { } // Accept implements Node Accept interface. -func (cd *ColumnDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (nod *ColumnDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := cd.Name.Accept(v) + nod = newNod.(*ColumnDef) + node, ok := nod.Name.Accept(v) if !ok { - return cd, false + return nod, false } - cd.Name = node.(*ColumnName) - for i, val := range cd.Options { + nod.Name = node.(*ColumnName) + for i, val := range nod.Options { node, ok := val.Accept(v) if !ok { - return cd, false + return nod, false } - cd.Options[i] = node.(*ColumnOption) + nod.Options[i] = node.(*ColumnOption) } - return v.Leave(cd) + return v.Leave(nod) } // CreateTableStmt is a statement to create a table. @@ -283,30 +297,32 @@ type CreateTableStmt struct { } // Accept implements Node Accept interface. -func (ct *CreateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ct); skipChildren { - return ct, ok +func (nod *CreateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ct.Table.Accept(v) + nod = newNod.(*CreateTableStmt) + node, ok := nod.Table.Accept(v) if !ok { - return ct, false + return nod, false } - ct.Table = node.(*TableName) - for i, val := range ct.Cols { + nod.Table = node.(*TableName) + for i, val := range nod.Cols { node, ok = val.Accept(v) if !ok { - return ct, false + return nod, false } - ct.Cols[i] = node.(*ColumnDef) + nod.Cols[i] = node.(*ColumnDef) } - for i, val := range ct.Constraints { + for i, val := range nod.Constraints { node, ok = val.Accept(v) if !ok { - return ct, false + return nod, false } - ct.Constraints[i] = node.(*Constraint) + nod.Constraints[i] = node.(*Constraint) } - return v.Leave(ct) + return v.Leave(nod) } // DropTableStmt is a statement to drop one or more tables. @@ -319,18 +335,20 @@ type DropTableStmt struct { } // Accept implements Node Accept interface. -func (dt *DropTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dt); skipChildren { - return dt, ok +func (nod *DropTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range dt.Tables { + nod = newNod.(*DropTableStmt) + for i, val := range nod.Tables { node, ok := val.Accept(v) if !ok { - return dt, false + return nod, false } - dt.Tables[i] = node.(*TableName) + nod.Tables[i] = node.(*TableName) } - return v.Leave(dt) + return v.Leave(nod) } // CreateIndexStmt is a statement to create an index. @@ -345,23 +363,25 @@ type CreateIndexStmt struct { } // Accept implements Node Accept interface. -func (ci *CreateIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ci); skipChildren { - return ci, ok +func (nod *CreateIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ci.Table.Accept(v) + nod = newNod.(*CreateIndexStmt) + node, ok := nod.Table.Accept(v) if !ok { - return ci, false + return nod, false } - ci.Table = node.(*TableName) - for i, val := range ci.IndexColNames { + nod.Table = node.(*TableName) + for i, val := range nod.IndexColNames { node, ok = val.Accept(v) if !ok { - return ci, false + return nod, false } - ci.IndexColNames[i] = node.(*IndexColName) + nod.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(ci) + return v.Leave(nod) } // DropIndexStmt is a statement to drop the index. @@ -375,16 +395,18 @@ type DropIndexStmt struct { } // Accept implements Node Accept interface. -func (di *DropIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(di); skipChildren { - return di, ok +func (nod *DropIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := di.Table.Accept(v) + nod = newNod.(*DropIndexStmt) + node, ok := nod.Table.Accept(v) if !ok { - return di, false + return nod, false } - di.Table = node.(*TableName) - return v.Leave(di) + nod.Table = node.(*TableName) + return v.Leave(nod) } // TableOptionType is the type for TableOption @@ -435,18 +457,20 @@ type ColumnPosition struct { } // Accept implements Node Accept interface. -func (cp *ColumnPosition) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cp); skipChildren { - return cp, ok +func (nod *ColumnPosition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if cp.RelativeColumn != nil { - node, ok := cp.RelativeColumn.Accept(v) + nod = newNod.(*ColumnPosition) + if nod.RelativeColumn != nil { + node, ok := nod.RelativeColumn.Accept(v) if !ok { - return cp, false + return nod, false } - cp.RelativeColumn = node.(*ColumnName) + nod.RelativeColumn = node.(*ColumnName) } - return v.Leave(cp) + return v.Leave(nod) } // AlterTableType is the type for AlterTableSpec. @@ -479,39 +503,41 @@ type AlterTableSpec struct { } // Accept implements Node Accept interface. -func (as *AlterTableSpec) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (nod *AlterTableSpec) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if as.Constraint != nil { - node, ok := as.Constraint.Accept(v) + nod = newNod.(*AlterTableSpec) + if nod.Constraint != nil { + node, ok := nod.Constraint.Accept(v) if !ok { - return as, false + return nod, false } - as.Constraint = node.(*Constraint) + nod.Constraint = node.(*Constraint) } - if as.Column != nil { - node, ok := as.Column.Accept(v) + if nod.Column != nil { + node, ok := nod.Column.Accept(v) if !ok { - return as, false + return nod, false } - as.Column = node.(*ColumnDef) + nod.Column = node.(*ColumnDef) } - if as.ColumnName != nil { - node, ok := as.ColumnName.Accept(v) + if nod.ColumnName != nil { + node, ok := nod.ColumnName.Accept(v) if !ok { - return as, false + return nod, false } - as.ColumnName = node.(*ColumnName) + nod.ColumnName = node.(*ColumnName) } - if as.Position != nil { - node, ok := as.Position.Accept(v) + if nod.Position != nil { + node, ok := nod.Position.Accept(v) if !ok { - return as, false + return nod, false } - as.Position = node.(*ColumnPosition) + nod.Position = node.(*ColumnPosition) } - return v.Leave(as) + return v.Leave(nod) } // AlterTableStmt is a statement to change the structure of a table. @@ -524,23 +550,25 @@ type AlterTableStmt struct { } // Accept implements Node Accept interface. -func (at *AlterTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(at); skipChildren { - return at, ok +func (nod *AlterTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := at.Table.Accept(v) + nod = newNod.(*AlterTableStmt) + node, ok := nod.Table.Accept(v) if !ok { - return at, false + return nod, false } - at.Table = node.(*TableName) - for i, val := range at.Specs { + nod.Table = node.(*TableName) + for i, val := range nod.Specs { node, ok = val.Accept(v) if !ok { - return at, false + return nod, false } - at.Specs[i] = node.(*AlterTableSpec) + nod.Specs[i] = node.(*AlterTableSpec) } - return v.Leave(at) + return v.Leave(nod) } // TruncateTableStmt is a statement to empty a table completely. @@ -552,14 +580,16 @@ type TruncateTableStmt struct { } // Accept implements Node Accept interface. -func (ts *TruncateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (nod *TruncateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Table.Accept(v) + nod = newNod.(*TruncateTableStmt) + node, ok := nod.Table.Accept(v) if !ok { - return ts, false + return nod, false } - ts.Table = node.(*TableName) - return v.Leave(ts) + nod.Table = node.(*TableName) + return v.Leave(nod) } diff --git a/ast/dml.go b/ast/dml.go index b8ada81d99..f7f11421bd 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -55,34 +55,36 @@ type Join struct { // Tp represents join type. Tp JoinType // On represents join on condition. - On ExprNode + On *OnCondition } // Accept implements Node Accept interface. -func (j *Join) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(j); skipChildren { - return j, ok +func (nod *Join) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := j.Left.Accept(v) + nod = newNod.(*Join) + node, ok := nod.Left.Accept(v) if !ok { - return j, false + return nod, false } - j.Left = node.(ResultSetNode) - if j.Right != nil { - node, ok = j.Right.Accept(v) + nod.Left = node.(ResultSetNode) + if nod.Right != nil { + node, ok = nod.Right.Accept(v) if !ok { - return j, false + return nod, false } - j.Right = node.(ResultSetNode) + nod.Right = node.(ResultSetNode) } - if j.On != nil { - node, ok = j.On.Accept(v) + if nod.On != nil { + node, ok = nod.On.Accept(v) if !ok { - return j, false + return nod, false } - j.On = node.(ExprNode) + nod.On = node.(*OnCondition) } - return v.Leave(j) + return v.Leave(nod) } // TableName represents a table name. @@ -98,11 +100,13 @@ type TableName struct { } // Accept implements Node Accept interface. -func (tr *TableName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tr); skipChildren { - return tr, ok +func (nod *TableName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(tr) + nod = newNod.(*TableName) + return v.Leave(nod) } // TableSource represents table source with a name. @@ -118,47 +122,50 @@ type TableSource struct { } // Accept implements Node Accept interface. -func (ts *TableSource) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (nod *TableSource) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Source.Accept(v) + nod = newNod.(*TableSource) + node, ok := nod.Source.Accept(v) if !ok { - return ts, false + return nod, false } - ts.Source = node.(ResultSetNode) - return v.Leave(ts) + nod.Source = node.(ResultSetNode) + return v.Leave(nod) } -// SetResultFields implements ResultSet interface. -func (ts *TableSource) SetResultFields(rfs []*ResultField) { - ts.Source.SetResultFields(rfs) -} - -// GetResultFields implements ResultSet interface. -func (ts *TableSource) GetResultFields() []*ResultField { - return ts.Source.GetResultFields() -} - -// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. -type UnionClause struct { +// OnCondition represetns JOIN on condition. +type OnCondition struct { node - Distinct bool - Select *SelectStmt + Expr ExprNode } // Accept implements Node Accept interface. -func (uc *UnionClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(uc); skipChildren { - return uc, ok +func (nod *OnCondition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := uc.Select.Accept(v) + nod = newNod.(*OnCondition) + node, ok := nod.Expr.Accept(v) if !ok { - return uc, false + return nod, false } - uc.Select = node.(*SelectStmt) - return v.Leave(uc) + nod.Expr = node.(ExprNode) + return v.Leave(nod) +} + +// SetResultFields implements ResultSet interface. +func (nod *TableSource) SetResultFields(rfs []*ResultField) { + nod.Source.SetResultFields(rfs) +} + +// GetResultFields implements ResultSet interface. +func (nod *TableSource) GetResultFields() []*ResultField { + return nod.Source.GetResultFields() } // SelectLockType is the lock type for SelectStmt. @@ -175,22 +182,18 @@ const ( type WildCardField struct { node - Table *TableName + Table model.CIStr + Schema model.CIStr } // Accept implements Node Accept interface. -func (wf *WildCardField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(wf); skipChildren { - return wf, ok +func (nod *WildCardField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if wf.Table != nil { - node, ok := wf.Table.Accept(v) - if !ok { - return wf, false - } - wf.Table = node.(*TableName) - } - return v.Leave(wf) + nod = newNod.(*WildCardField) + return v.Leave(nod) } // SelectField represents fields in select statement. @@ -208,22 +211,70 @@ type SelectField struct { } // Accept implements Node Accept interface. -func (sf *SelectField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (nod *SelectField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if sf.Expr != nil { - node, ok := sf.Expr.Accept(v) + nod = newNod.(*SelectField) + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) if !ok { - return sf, false + return nod, false } - sf.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) } - return v.Leave(sf) + return v.Leave(nod) } -// OrderByItem represents a single order by item. -type OrderByItem struct { +// FieldList represents field list in select statement. +type FieldList struct { + node + + Fields []*SelectField +} + +// Accept implements Node Accept interface. +func (nod *FieldList) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*FieldList) + for i, val := range nod.Fields { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Fields[i] = node.(*SelectField) + } + return v.Leave(nod) +} + +// TableRefsClause represents table references clause in dml statement. +type TableRefsClause struct { + node + + TableRefs *Join +} + +// Accept implements Node Accept interface. +func (nod *TableRefsClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*TableRefsClause) + node, ok := nod.TableRefs.Accept(v) + if !ok { + return nod, false + } + nod.TableRefs = node.(*Join) + return v.Leave(nod) +} + +// ByItem represents an item in order by or group by. +type ByItem struct { node Expr ExprNode @@ -231,131 +282,236 @@ type OrderByItem struct { } // Accept implements Node Accept interface. -func (ob *OrderByItem) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ob); skipChildren { - return ob, ok +func (nod *ByItem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ob.Expr.Accept(v) + nod = newNod.(*ByItem) + node, ok := nod.Expr.Accept(v) if !ok { - return ob, false + return nod, false } - ob.Expr = node.(ExprNode) - return v.Leave(ob) + nod.Expr = node.(ExprNode) + return v.Leave(nod) +} + +// GroupByClause represents group by clause. +type GroupByClause struct { + node + Items []*ByItem +} + +// Accept implements Node Accept interface. +func (nod *GroupByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*GroupByClause) + for i, val := range nod.Items { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Items[i] = node.(*ByItem) + } + return v.Leave(nod) +} + +// HavingClause represents having clause. +type HavingClause struct { + node + Expr ExprNode +} + +// Accept implements Node Accept interface. +func (nod *HavingClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*HavingClause) + node, ok := nod.Expr.Accept(v) + if !ok { + return nod, false + } + nod.Expr = node.(ExprNode) + return v.Leave(nod) +} + +// OrderByClause represents order by clause. +type OrderByClause struct { + node + Items []*ByItem + ForUnion bool +} + +// Accept implements Node Accept interface. +func (nod *OrderByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*OrderByClause) + for i, val := range nod.Items { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Items[i] = node.(*ByItem) + } + return v.Leave(nod) } // SelectStmt represents the select query node. +// See: https://dev.mysql.com/doc/refman/5.7/en/select.html type SelectStmt struct { dmlNode resultSetNode // Distinct represents if the select has distinct option. Distinct bool - // Fields is the select expression list. - Fields []*SelectField // From is the from clause of the query. - From *Join + From *TableRefsClause // Where is the where clause in select statement. Where ExprNode + // Fields is the select expression list. + Fields *FieldList // GroupBy is the group by expression list. - GroupBy []ExprNode + GroupBy *GroupByClause // Having is the having condition. - Having ExprNode - // OrderBy is the odering expression list. - OrderBy []*OrderByItem + Having *HavingClause + // OrderBy is the ordering expression list. + OrderBy *OrderByClause // Limit is the limit clause. Limit *Limit // Lock is the lock type LockTp SelectLockType - - // Union clauses. - Unions []*UnionClause - // Order by for union select. - UnionOrderBy []*OrderByItem - // Limit for union select. - UnionLimit *Limit } // Accept implements Node Accept interface. -func (sn *SelectStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sn); skipChildren { - return sn, ok +func (nod *SelectStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range sn.Fields { - node, ok := val.Accept(v) + nod = newNod.(*SelectStmt) + + if nod.From != nil { + node, ok := nod.From.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Fields[i] = node.(*SelectField) - } - if sn.From != nil { - node, ok := sn.From.Accept(v) - if !ok { - return sn, false - } - sn.From = node.(*Join) + nod.From = node.(*TableRefsClause) } - if sn.Where != nil { - node, ok := sn.Where.Accept(v) + if nod.Where != nil { + node, ok := nod.Where.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - for i, val := range sn.GroupBy { - node, ok := val.Accept(v) + if nod.Fields != nil { + node, ok := nod.Fields.Accept(v) if !ok { - return sn, false + return nod, false } - sn.GroupBy[i] = node.(ExprNode) - } - if sn.Having != nil { - node, ok := sn.Having.Accept(v) - if !ok { - return sn, false - } - sn.Having = node.(ExprNode) + nod.Fields = node.(*FieldList) } - for i, val := range sn.OrderBy { - node, ok := val.Accept(v) + if nod.GroupBy != nil { + node, ok := nod.GroupBy.Accept(v) if !ok { - return sn, false + return nod, false } - sn.OrderBy[i] = node.(*OrderByItem) + nod.GroupBy = node.(*GroupByClause) } - if sn.Limit != nil { - node, ok := sn.Limit.Accept(v) + if nod.Having != nil { + node, ok := nod.Having.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Limit = node.(*Limit) + nod.Having = node.(*HavingClause) } - for i, val := range sn.Unions { + if nod.OrderBy != nil { + node, ok := nod.OrderBy.Accept(v) + if !ok { + return nod, false + } + nod.OrderBy = node.(*OrderByClause) + } + + if nod.Limit != nil { + node, ok := nod.Limit.Accept(v) + if !ok { + return nod, false + } + nod.Limit = node.(*Limit) + } + return v.Leave(nod) +} + +// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. +type UnionClause struct { + node + + Distinct bool + Select *SelectStmt +} + +// Accept implements Node Accept interface. +func (nod *UnionClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*UnionClause) + node, ok := nod.Select.Accept(v) + if !ok { + return nod, false + } + nod.Select = node.(*SelectStmt) + return v.Leave(nod) +} + +// UnionStmt represents "union statement" +// See: https://dev.mysql.com/doc/refman/5.7/en/union.html +type UnionStmt struct { + dmlNode + resultSetNode + + Distinct bool + Selects []*SelectStmt + Limit *Limit +} + +// Accept implements Node Accept interface. +func (nod *UnionStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*UnionStmt) + for i, val := range nod.Selects { node, ok := val.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Unions[i] = node.(*UnionClause) + nod.Selects[i] = node.(*SelectStmt) } - for i, val := range sn.UnionOrderBy { - node, ok := val.Accept(v) + if nod.Limit != nil { + node, ok := nod.Limit.Accept(v) if !ok { - return sn, false + return nod, false } - sn.UnionOrderBy[i] = node.(*OrderByItem) + nod.Limit = node.(*Limit) } - if sn.UnionLimit != nil { - node, ok := sn.UnionLimit.Accept(v) - if !ok { - return sn, false - } - sn.UnionLimit = node.(*Limit) - } - return v.Leave(sn) + return v.Leave(nod) } // Assignment is the expression for assignment, like a = 1. @@ -368,21 +524,23 @@ type Assignment struct { } // Accept implements Node Accept interface. -func (as *Assignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (nod *Assignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := as.Column.Accept(v) + nod = newNod.(*Assignment) + node, ok := nod.Column.Accept(v) if !ok { - return as, false + return nod, false } - as.Column = node.(*ColumnName) - node, ok = as.Expr.Accept(v) + nod.Column = node.(*ColumnName) + node, ok = nod.Expr.Accept(v) if !ok { - return as, false + return nod, false } - as.Expr = node.(ExprNode) - return v.Leave(as) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // Priority const values. @@ -399,9 +557,9 @@ const ( type InsertStmt struct { dmlNode + Table *TableRefsClause Columns []*ColumnName Lists [][]ExprNode - Table *TableName Setlist []*Assignment Priority int OnDuplicate []*Assignment @@ -409,48 +567,57 @@ type InsertStmt struct { } // Accept implements Node Accept interface. -func (in *InsertStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(in); skipChildren { - return in, ok +func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range in.Columns { + nod = newNod.(*InsertStmt) + + node, ok := nod.Table.Accept(v) + if !ok { + return nod, false + } + nod.Table = node.(*TableRefsClause) + + for i, val := range nod.Columns { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.Columns[i] = node.(*ColumnName) + nod.Columns[i] = node.(*ColumnName) } - for i, list := range in.Lists { + for i, list := range nod.Lists { for j, val := range list { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.Lists[i][j] = node.(ExprNode) + nod.Lists[i][j] = node.(ExprNode) } } - for i, val := range in.Setlist { + for i, val := range nod.Setlist { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.Setlist[i] = node.(*Assignment) + nod.Setlist[i] = node.(*Assignment) } - for i, val := range in.OnDuplicate { + for i, val := range nod.OnDuplicate { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.OnDuplicate[i] = node.(*Assignment) + nod.OnDuplicate[i] = node.(*Assignment) } - if in.Select != nil { - node, ok := in.Select.Accept(v) + if nod.Select != nil { + node, ok := nod.Select.Accept(v) if !ok { - return in, false + return nod, false } - in.Select = node.(*SelectStmt) + nod.Select = node.(*SelectStmt) } - return v.Leave(in) + return v.Leave(nod) } // DeleteStmt is a statement to delete rows from table. @@ -458,10 +625,12 @@ func (in *InsertStmt) Accept(v Visitor) (Node, bool) { type DeleteStmt struct { dmlNode - TableRefs *Join + // Used in both single table and multiple table delete statement. + TableRefs *TableRefsClause + // Only used in multiple table delete statement. Tables []*TableName Where ExprNode - Order []*OrderByItem + Order []*ByItem Limit *Limit LowPriority bool Ignore bool @@ -471,47 +640,49 @@ type DeleteStmt struct { } // Accept implements Node Accept interface. -func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(de); skipChildren { - return de, ok +func (nod *DeleteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } + nod = newNod.(*DeleteStmt) - node, ok := de.TableRefs.Accept(v) + node, ok := nod.TableRefs.Accept(v) if !ok { - return de, false + return nod, false } - de.TableRefs = node.(*Join) + nod.TableRefs = node.(*TableRefsClause) - for i, val := range de.Tables { + for i, val := range nod.Tables { node, ok = val.Accept(v) if !ok { - return de, false + return nod, false } - de.Tables[i] = node.(*TableName) + nod.Tables[i] = node.(*TableName) } - if de.Where != nil { - node, ok = de.Where.Accept(v) + if nod.Where != nil { + node, ok = nod.Where.Accept(v) if !ok { - return de, false + return nod, false } - de.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - for i, val := range de.Order { + for i, val := range nod.Order { node, ok = val.Accept(v) if !ok { - return de, false + return nod, false } - de.Order[i] = node.(*OrderByItem) + nod.Order[i] = node.(*ByItem) } - node, ok = de.Limit.Accept(v) + node, ok = nod.Limit.Accept(v) if !ok { - return de, false + return nod, false } - de.Limit = node.(*Limit) - return v.Leave(de) + nod.Limit = node.(*Limit) + return v.Leave(nod) } // UpdateStmt is a statement to update columns of existing rows in tables with new values. @@ -519,10 +690,10 @@ func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { type UpdateStmt struct { dmlNode - TableRefs *Join + TableRefs *TableRefsClause List []*Assignment Where ExprNode - Order []*OrderByItem + Order []*ByItem Limit *Limit LowPriority bool Ignore bool @@ -530,43 +701,45 @@ type UpdateStmt struct { } // Accept implements Node Accept interface. -func (up *UpdateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(up); skipChildren { - return up, ok +func (nod *UpdateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := up.TableRefs.Accept(v) + nod = newNod.(*UpdateStmt) + node, ok := nod.TableRefs.Accept(v) if !ok { - return up, false + return nod, false } - up.TableRefs = node.(*Join) - for i, val := range up.List { + nod.TableRefs = node.(*TableRefsClause) + for i, val := range nod.List { node, ok = val.Accept(v) if !ok { - return up, false + return nod, false } - up.List[i] = node.(*Assignment) + nod.List[i] = node.(*Assignment) } - if up.Where != nil { - node, ok = up.Where.Accept(v) + if nod.Where != nil { + node, ok = nod.Where.Accept(v) if !ok { - return up, false + return nod, false } - up.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - for i, val := range up.Order { + for i, val := range nod.Order { node, ok = val.Accept(v) if !ok { - return up, false + return nod, false } - up.Order[i] = node.(*OrderByItem) + nod.Order[i] = node.(*ByItem) } - node, ok = up.Limit.Accept(v) + node, ok = nod.Limit.Accept(v) if !ok { - return up, false + return nod, false } - up.Limit = node.(*Limit) - return v.Leave(up) + nod.Limit = node.(*Limit) + return v.Leave(nod) } // Limit is the limit clause. @@ -578,9 +751,11 @@ type Limit struct { } // Accept implements Node Accept interface. -func (l *Limit) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(l); skipChildren { - return l, ok +func (nod *Limit) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(l) + nod = newNod.(*Limit) + return v.Leave(nod) } diff --git a/ast/expressions.go b/ast/expressions.go index 558f1a019c..c99611ae13 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -52,16 +52,18 @@ type ValueExpr struct { } // IsStatic implements ExprNode interface. -func (val *ValueExpr) IsStatic() bool { +func (nod *ValueExpr) IsStatic() bool { return true } // Accept implements Node interface. -func (val *ValueExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(val); skipChildren { - return val, ok +func (nod *ValueExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(val) + nod = newNod.(*ValueExpr) + return v.Leave(nod) } // BetweenExpr is for "between and" or "not between and" expression. @@ -78,35 +80,37 @@ type BetweenExpr struct { } // Accept implements Node interface. -func (b *BetweenExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(b); skipChildren { - return b, ok +func (nod *BetweenExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } + nod = newNod.(*BetweenExpr) - node, ok := b.Expr.Accept(v) + node, ok := nod.Expr.Accept(v) if !ok { - return b, false + return nod, false } - b.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) - node, ok = b.Left.Accept(v) + node, ok = nod.Left.Accept(v) if !ok { - return b, false + return nod, false } - b.Left = node.(ExprNode) + nod.Left = node.(ExprNode) - node, ok = b.Right.Accept(v) + node, ok = nod.Right.Accept(v) if !ok { - return b, false + return nod, false } - b.Right = node.(ExprNode) + nod.Right = node.(ExprNode) - return v.Leave(b) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (b *BetweenExpr) IsStatic() bool { - return b.Expr.IsStatic() && b.Left.IsStatic() && b.Right.IsStatic() +func (nod *BetweenExpr) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Left.IsStatic() && nod.Right.IsStatic() } // BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc. @@ -121,29 +125,31 @@ type BinaryOperationExpr struct { } // Accept implements Node interface. -func (o *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(o); skipChildren { - return o, ok +func (nod *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } + nod = newNod.(*BinaryOperationExpr) - node, ok := o.L.Accept(v) + node, ok := nod.L.Accept(v) if !ok { - return o, false + return nod, false } - o.L = node.(ExprNode) + nod.L = node.(ExprNode) - node, ok = o.R.Accept(v) + node, ok = nod.R.Accept(v) if !ok { - return o, false + return nod, false } - o.R = node.(ExprNode) + nod.R = node.(ExprNode) - return v.Leave(o) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (o *BinaryOperationExpr) IsStatic() bool { - return o.L.IsStatic() && o.R.IsStatic() +func (nod *BinaryOperationExpr) IsStatic() bool { + return nod.L.IsStatic() && nod.R.IsStatic() } // WhenClause is the when clause in Case expression for "when condition then result". @@ -156,27 +162,29 @@ type WhenClause struct { } // Accept implements Node Accept interface. -func (w *WhenClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(w); skipChildren { - return w, ok +func (nod *WhenClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := w.Expr.Accept(v) + nod = newNod.(*WhenClause) + node, ok := nod.Expr.Accept(v) if !ok { - return w, false + return nod, false } - w.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) - node, ok = w.Result.Accept(v) + node, ok = nod.Result.Accept(v) if !ok { - return w, false + return nod, false } - w.Result = node.(ExprNode) - return v.Leave(w) + nod.Result = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (w *WhenClause) IsStatic() bool { - return w.Expr.IsStatic() && w.Result.IsStatic() +func (nod *WhenClause) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Result.IsStatic() } // CaseExpr is the case expression. @@ -191,45 +199,47 @@ type CaseExpr struct { } // Accept implements Node Accept interface. -func (f *CaseExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (nod *CaseExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if f.Value != nil { - node, ok := f.Value.Accept(v) + nod = newNod.(*CaseExpr) + if nod.Value != nil { + node, ok := nod.Value.Accept(v) if !ok { - return f, false + return nod, false } - f.Value = node.(ExprNode) + nod.Value = node.(ExprNode) } - for i, val := range f.WhenClauses { + for i, val := range nod.WhenClauses { node, ok := val.Accept(v) if !ok { - return f, false + return nod, false } - f.WhenClauses[i] = node.(*WhenClause) + nod.WhenClauses[i] = node.(*WhenClause) } - if f.ElseClause != nil { - node, ok := f.ElseClause.Accept(v) + if nod.ElseClause != nil { + node, ok := nod.ElseClause.Accept(v) if !ok { - return f, false + return nod, false } - f.ElseClause = node.(ExprNode) + nod.ElseClause = node.(ExprNode) } - return v.Leave(f) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (f *CaseExpr) IsStatic() bool { - if f.Value != nil && !f.Value.IsStatic() { +func (nod *CaseExpr) IsStatic() bool { + if nod.Value != nil && !nod.Value.IsStatic() { return false } - for _, w := range f.WhenClauses { + for _, w := range nod.WhenClauses { if !w.IsStatic() { return false } } - if f.ElseClause != nil && !f.ElseClause.IsStatic() { + if nod.ElseClause != nil && !nod.ElseClause.IsStatic() { return false } return true @@ -243,26 +253,28 @@ type SubqueryExpr struct { } // Accept implements Node Accept interface. -func (sq *SubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sq); skipChildren { - return sq, ok +func (nod *SubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := sq.Query.Accept(v) + nod = newNod.(*SubqueryExpr) + node, ok := nod.Query.Accept(v) if !ok { - return sq, false + return nod, false } - sq.Query = node.(*SelectStmt) - return v.Leave(sq) + nod.Query = node.(*SelectStmt) + return v.Leave(nod) } // SetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) SetResultFields(rfs []*ResultField) { - sq.Query.SetResultFields(rfs) +func (nod *SubqueryExpr) SetResultFields(rfs []*ResultField) { + nod.Query.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) GetResultFields() []*ResultField { - return sq.Query.GetResultFields() +func (nod *SubqueryExpr) GetResultFields() []*ResultField { + return nod.Query.GetResultFields() } // CompareSubqueryExpr is the expression for "expr cmp (select ...)". @@ -282,21 +294,23 @@ type CompareSubqueryExpr struct { } // Accept implements Node Accept interface. -func (cs *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (nod *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := cs.L.Accept(v) + nod = newNod.(*CompareSubqueryExpr) + node, ok := nod.L.Accept(v) if !ok { - return cs, false + return nod, false } - cs.L = node.(ExprNode) - node, ok = cs.R.Accept(v) + nod.L = node.(ExprNode) + node, ok = nod.R.Accept(v) if !ok { - return cs, false + return nod, false } - cs.R = node.(*SubqueryExpr) - return v.Leave(cs) + nod.R = node.(*SubqueryExpr) + return v.Leave(nod) } // ColumnName represents column name. @@ -312,11 +326,13 @@ type ColumnName struct { } // Accept implements Node Accept interface. -func (cn *ColumnName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cn); skipChildren { - return cn, ok +func (nod *ColumnName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cn) + nod = newNod.(*ColumnName) + return v.Leave(nod) } // ColumnNameExpr represents a column name expression. @@ -328,16 +344,18 @@ type ColumnNameExpr struct { } // Accept implements Node Accept interface. -func (cr *ColumnNameExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cr); skipChildren { - return cr, ok +func (nod *ColumnNameExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := cr.Name.Accept(v) + nod = newNod.(*ColumnNameExpr) + node, ok := nod.Name.Accept(v) if !ok { - return cr, false + return nod, false } - cr.Name = node.(*ColumnName) - return v.Leave(cr) + nod.Name = node.(*ColumnName) + return v.Leave(nod) } // DefaultExpr is the default expression using default value for a column. @@ -348,18 +366,20 @@ type DefaultExpr struct { } // Accept implements Node Accept interface. -func (d *DefaultExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(d); skipChildren { - return d, ok +func (nod *DefaultExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if d.Name != nil { - node, ok := d.Name.Accept(v) + nod = newNod.(*DefaultExpr) + if nod.Name != nil { + node, ok := nod.Name.Accept(v) if !ok { - return d, false + return nod, false } - d.Name = node.(*ColumnName) + nod.Name = node.(*ColumnName) } - return v.Leave(d) + return v.Leave(nod) } // IdentifierExpr represents an identifier expression. @@ -370,11 +390,13 @@ type IdentifierExpr struct { } // Accept implements Node Accept interface. -func (i *IdentifierExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(i); skipChildren { - return i, ok +func (nod *IdentifierExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(i) + nod = newNod.(*IdentifierExpr) + return v.Leave(nod) } // ExistsSubqueryExpr is the expression for "exists (select ...)". @@ -386,16 +408,18 @@ type ExistsSubqueryExpr struct { } // Accept implements Node Accept interface. -func (es *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (nod *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Sel.Accept(v) + nod = newNod.(*ExistsSubqueryExpr) + node, ok := nod.Sel.Accept(v) if !ok { - return es, false + return nod, false } - es.Sel = node.(*SubqueryExpr) - return v.Leave(es) + nod.Sel = node.(*SubqueryExpr) + return v.Leave(nod) } // PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)". @@ -412,30 +436,32 @@ type PatternInExpr struct { } // Accept implements Node Accept interface. -func (pi *PatternInExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pi); skipChildren { - return pi, ok +func (nod *PatternInExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := pi.Expr.Accept(v) + nod = newNod.(*PatternInExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return pi, false + return nod, false } - pi.Expr = node.(ExprNode) - for i, val := range pi.List { + nod.Expr = node.(ExprNode) + for i, val := range nod.List { node, ok = val.Accept(v) if !ok { - return pi, false + return nod, false } - pi.List[i] = node.(ExprNode) + nod.List[i] = node.(ExprNode) } - if pi.Sel != nil { - node, ok = pi.Sel.Accept(v) + if nod.Sel != nil { + node, ok = nod.Sel.Accept(v) if !ok { - return pi, false + return nod, false } - pi.Sel = node.(*SubqueryExpr) + nod.Sel = node.(*SubqueryExpr) } - return v.Leave(pi) + return v.Leave(nod) } // IsNullExpr is the expression for null check. @@ -448,21 +474,23 @@ type IsNullExpr struct { } // Accept implements Node Accept interface. -func (is *IsNullExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (nod *IsNullExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + nod = newNod.(*IsNullExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return is, false + return nod, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsNullExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (nod *IsNullExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // IsTruthExpr is the expression for true/false check. @@ -477,21 +505,23 @@ type IsTruthExpr struct { } // Accept implements Node Accept interface. -func (is *IsTruthExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (nod *IsTruthExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + nod = newNod.(*IsTruthExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return is, false + return nod, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsTruthExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (nod *IsTruthExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // PatternLikeExpr is the expression for like operator, e.g, expr like "%123%" @@ -508,26 +538,28 @@ type PatternLikeExpr struct { } // Accept implements Node Accept interface. -func (pl *PatternLikeExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pl); skipChildren { - return pl, ok +func (nod *PatternLikeExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := pl.Expr.Accept(v) + nod = newNod.(*PatternLikeExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return pl, false + return nod, false } - pl.Expr = node.(ExprNode) - node, ok = pl.Pattern.Accept(v) + nod.Expr = node.(ExprNode) + node, ok = nod.Pattern.Accept(v) if !ok { - return pl, false + return nod, false } - pl.Pattern = node.(ExprNode) - return v.Leave(pl) + nod.Pattern = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (pl *PatternLikeExpr) IsStatic() bool { - return pl.Expr.IsStatic() && pl.Pattern.IsStatic() +func (nod *PatternLikeExpr) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Pattern.IsStatic() } // ParamMarkerExpr expresion holds a place for another expression. @@ -537,11 +569,13 @@ type ParamMarkerExpr struct { } // Accept implements Node Accept interface. -func (pm *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pm); skipChildren { - return pm, ok +func (nod *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(pm) + nod = newNod.(*ParamMarkerExpr) + return v.Leave(nod) } // ParenthesesExpr is the parentheses expression. @@ -552,23 +586,25 @@ type ParenthesesExpr struct { } // Accept implements Node Accept interface. -func (p *ParenthesesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (nod *ParenthesesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if p.Expr != nil { - node, ok := p.Expr.Accept(v) + nod = newNod.(*ParenthesesExpr) + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) if !ok { - return p, false + return nod, false } - p.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) } - return v.Leave(p) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (p *ParenthesesExpr) IsStatic() bool { - return p.Expr.IsStatic() +func (nod *ParenthesesExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // PositionExpr is the expression for order by and group by position. @@ -583,16 +619,18 @@ type PositionExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (p *PositionExpr) IsStatic() bool { +func (nod *PositionExpr) IsStatic() bool { return true } // Accept implements Node Accept interface. -func (p *PositionExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (nod *PositionExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(p) + nod = newNod.(*PositionExpr) + return v.Leave(nod) } // PatternRegexpExpr is the pattern expression for pattern match. @@ -607,26 +645,28 @@ type PatternRegexpExpr struct { } // Accept implements Node Accept interface. -func (p *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (nod *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := p.Expr.Accept(v) + nod = newNod.(*PatternRegexpExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return p, false + return nod, false } - p.Expr = node.(ExprNode) - node, ok = p.Pattern.Accept(v) + nod.Expr = node.(ExprNode) + node, ok = nod.Pattern.Accept(v) if !ok { - return p, false + return nod, false } - p.Pattern = node.(ExprNode) - return v.Leave(p) + nod.Pattern = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (p *PatternRegexpExpr) IsStatic() bool { - return p.Expr.IsStatic() && p.Pattern.IsStatic() +func (nod *PatternRegexpExpr) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Pattern.IsStatic() } // RowExpr is the expression for row constructor. @@ -638,23 +678,25 @@ type RowExpr struct { } // Accept implements Node Accept interface. -func (r *RowExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(r); skipChildren { - return r, ok +func (nod *RowExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range r.Values { + nod = newNod.(*RowExpr) + for i, val := range nod.Values { node, ok := val.Accept(v) if !ok { - return r, false + return nod, false } - r.Values[i] = node.(ExprNode) + nod.Values[i] = node.(ExprNode) } - return v.Leave(r) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (r *RowExpr) IsStatic() bool { - for _, v := range r.Values { +func (nod *RowExpr) IsStatic() bool { + for _, v := range nod.Values { if !v.IsStatic() { return false } @@ -672,21 +714,23 @@ type UnaryOperationExpr struct { } // Accept implements Node Accept interface. -func (u *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(u); skipChildren { - return u, ok +func (nod *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := u.V.Accept(v) + nod = newNod.(*UnaryOperationExpr) + node, ok := nod.V.Accept(v) if !ok { - return u, false + return nod, false } - u.V = node.(ExprNode) - return v.Leave(u) + nod.V = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (u *UnaryOperationExpr) IsStatic() bool { - return u.V.IsStatic() +func (nod *UnaryOperationExpr) IsStatic() bool { + return nod.V.IsStatic() } // ValuesExpr is the expression used in INSERT VALUES @@ -697,16 +741,18 @@ type ValuesExpr struct { } // Accept implements Node Accept interface. -func (va *ValuesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (nod *ValuesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Column.Accept(v) + nod = newNod.(*ValuesExpr) + node, ok := nod.Column.Accept(v) if !ok { - return va, false + return nod, false } - va.Column = node.(*ColumnName) - return v.Leave(va) + nod.Column = node.(*ColumnName) + return v.Leave(nod) } // VariableExpr is the expression for variable. @@ -721,9 +767,11 @@ type VariableExpr struct { } // Accept implements Node Accept interface. -func (va *VariableExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (nod *VariableExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(va) + nod = newNod.(*VariableExpr) + return v.Leave(nod) } diff --git a/ast/functions.go b/ast/functions.go index 3843478313..4e052b3ccf 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -42,28 +42,30 @@ type FuncCallExpr struct { } // Accept implements Node interface. -func (c *FuncCallExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(c); skipChildren { - return c, ok +func (nod *FuncCallExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range c.Args { + nod = newNod.(*FuncCallExpr) + for i, val := range nod.Args { node, ok := val.Accept(v) if !ok { - return c, false + return nod, false } - c.Args[i] = node.(ExprNode) + nod.Args[i] = node.(ExprNode) } - return v.Leave(c) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (c *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(c.F)] +func (nod *FuncCallExpr) IsStatic() bool { + v := builtin.Funcs[strings.ToLower(nod.F)] if v.F == nil || !v.IsStatic { return false } - for _, v := range c.Args { + for _, v := range nod.Args { if !v.IsStatic() { return false } @@ -81,21 +83,23 @@ type FuncExtractExpr struct { } // Accept implements Node Accept interface. -func (ex *FuncExtractExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ex); skipChildren { - return ex, ok +func (nod *FuncExtractExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ex.Date.Accept(v) + nod = newNod.(*FuncExtractExpr) + node, ok := nod.Date.Accept(v) if !ok { - return ex, false + return nod, false } - ex.Date = node.(ExprNode) - return v.Leave(ex) + nod.Date = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (ex *FuncExtractExpr) IsStatic() bool { - return ex.Date.IsStatic() +func (nod *FuncExtractExpr) IsStatic() bool { + return nod.Date.IsStatic() } // FuncConvertExpr provides a way to convert data between different character sets. @@ -109,21 +113,23 @@ type FuncConvertExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncConvertExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (nod *FuncConvertExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncConvertExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (nod *FuncConvertExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + nod = newNod.(*FuncConvertExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return f, false + return nod, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // CastFunctionType is the type for cast function. @@ -149,21 +155,23 @@ type FuncCastExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncCastExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (nod *FuncCastExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncCastExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (nod *FuncCastExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + nod = newNod.(*FuncCastExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return f, false + return nod, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // FuncSubstringExpr returns the substring as specified. @@ -177,31 +185,33 @@ type FuncSubstringExpr struct { } // Accept implements Node Accept interface. -func (sf *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (nod *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := sf.StrExpr.Accept(v) + nod = newNod.(*FuncSubstringExpr) + node, ok := nod.StrExpr.Accept(v) if !ok { - return sf, false + return nod, false } - sf.StrExpr = node.(ExprNode) - node, ok = sf.Pos.Accept(v) + nod.StrExpr = node.(ExprNode) + node, ok = nod.Pos.Accept(v) if !ok { - return sf, false + return nod, false } - sf.Pos = node.(ExprNode) - node, ok = sf.Len.Accept(v) + nod.Pos = node.(ExprNode) + node, ok = nod.Len.Accept(v) if !ok { - return sf, false + return nod, false } - sf.Len = node.(ExprNode) - return v.Leave(sf) + nod.Len = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (sf *FuncSubstringExpr) IsStatic() bool { - return sf.StrExpr.IsStatic() && sf.Pos.IsStatic() && sf.Len.IsStatic() +func (nod *FuncSubstringExpr) IsStatic() bool { + return nod.StrExpr.IsStatic() && nod.Pos.IsStatic() && nod.Len.IsStatic() } // FuncSubstringIndexExpr returns the substring as specified. @@ -215,26 +225,28 @@ type FuncSubstringIndexExpr struct { } // Accept implements Node Accept interface. -func (si *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(si); skipChildren { - return si, ok +func (nod *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := si.StrExpr.Accept(v) + nod = newNod.(*FuncSubstringIndexExpr) + node, ok := nod.StrExpr.Accept(v) if !ok { - return si, false + return nod, false } - si.StrExpr = node.(ExprNode) - node, ok = si.Delim.Accept(v) + nod.StrExpr = node.(ExprNode) + node, ok = nod.Delim.Accept(v) if !ok { - return si, false + return nod, false } - si.Delim = node.(ExprNode) - node, ok = si.Count.Accept(v) + nod.Delim = node.(ExprNode) + node, ok = nod.Count.Accept(v) if !ok { - return si, false + return nod, false } - si.Count = node.(ExprNode) - return v.Leave(si) + nod.Count = node.(ExprNode) + return v.Leave(nod) } // FuncLocateExpr returns the position of the first occurrence of substring. @@ -248,26 +260,28 @@ type FuncLocateExpr struct { } // Accept implements Node Accept interface. -func (le *FuncLocateExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(le); skipChildren { - return le, ok +func (nod *FuncLocateExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := le.Str.Accept(v) + nod = newNod.(*FuncLocateExpr) + node, ok := nod.Str.Accept(v) if !ok { - return le, false + return nod, false } - le.Str = node.(ExprNode) - node, ok = le.SubStr.Accept(v) + nod.Str = node.(ExprNode) + node, ok = nod.SubStr.Accept(v) if !ok { - return le, false + return nod, false } - le.SubStr = node.(ExprNode) - node, ok = le.Pos.Accept(v) + nod.SubStr = node.(ExprNode) + node, ok = nod.Pos.Accept(v) if !ok { - return le, false + return nod, false } - le.Pos = node.(ExprNode) - return v.Leave(le) + nod.Pos = node.(ExprNode) + return v.Leave(nod) } // TrimDirectionType is the type for trim direction. @@ -295,26 +309,28 @@ type FuncTrimExpr struct { } // Accept implements Node Accept interface. -func (tf *FuncTrimExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tf); skipChildren { - return tf, ok +func (nod *FuncTrimExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := tf.Str.Accept(v) + nod = newNod.(*FuncTrimExpr) + node, ok := nod.Str.Accept(v) if !ok { - return tf, false + return nod, false } - tf.Str = node.(ExprNode) - node, ok = tf.RemStr.Accept(v) + nod.Str = node.(ExprNode) + node, ok = nod.RemStr.Accept(v) if !ok { - return tf, false + return nod, false } - tf.RemStr = node.(ExprNode) - return v.Leave(tf) + nod.RemStr = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (tf *FuncTrimExpr) IsStatic() bool { - return tf.Str.IsStatic() && tf.RemStr.IsStatic() +func (nod *FuncTrimExpr) IsStatic() bool { + return nod.Str.IsStatic() && nod.RemStr.IsStatic() } // TypeStar is a special type for "*". diff --git a/ast/misc.go b/ast/misc.go index 2ad94e7755..d269e27c84 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -57,16 +57,18 @@ type ExplainStmt struct { } // Accept implements Node Accept interface. -func (es *ExplainStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (nod *ExplainStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Stmt.Accept(v) + nod = newNod.(*ExplainStmt) + node, ok := nod.Stmt.Accept(v) if !ok { - return es, false + return nod, false } - es.Stmt = node.(DMLNode) - return v.Leave(es) + nod.Stmt = node.(DMLNode) + return v.Leave(nod) } // PrepareStmt is a statement to prepares a SQL statement which contains placeholders, @@ -83,16 +85,18 @@ type PrepareStmt struct { } // Accept implements Node Accept interface. -func (ps *PrepareStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ps); skipChildren { - return ps, ok +func (nod *PrepareStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ps.SQLVar.Accept(v) + nod = newNod.(*PrepareStmt) + node, ok := nod.SQLVar.Accept(v) if !ok { - return ps, false + return nod, false } - ps.SQLVar = node.(*VariableExpr) - return v.Leave(ps) + nod.SQLVar = node.(*VariableExpr) + return v.Leave(nod) } // DeallocateStmt is a statement to release PreparedStmt. @@ -105,11 +109,13 @@ type DeallocateStmt struct { } // Accept implements Node Accept interface. -func (ds *DeallocateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ds); skipChildren { - return ds, ok +func (nod *DeallocateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(ds) + nod = newNod.(*DeallocateStmt) + return v.Leave(nod) } // ExecuteStmt is a statement to execute PreparedStmt. @@ -123,18 +129,20 @@ type ExecuteStmt struct { } // Accept implements Node Accept interface. -func (es *ExecuteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (nod *ExecuteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range es.UsingVars { + nod = newNod.(*ExecuteStmt) + for i, val := range nod.UsingVars { node, ok := val.Accept(v) if !ok { - return es, false + return nod, false } - es.UsingVars[i] = node.(ExprNode) + nod.UsingVars[i] = node.(ExprNode) } - return v.Leave(es) + return v.Leave(nod) } // ShowStmtType is the type for SHOW statement. @@ -173,39 +181,41 @@ type ShowStmt struct { } // Accept implements Node Accept interface. -func (ss *ShowStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ss); skipChildren { - return ss, ok +func (nod *ShowStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if ss.Table != nil { - node, ok := ss.Table.Accept(v) + nod = newNod.(*ShowStmt) + if nod.Table != nil { + node, ok := nod.Table.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Table = node.(*TableName) + nod.Table = node.(*TableName) } - if ss.Column != nil { - node, ok := ss.Column.Accept(v) + if nod.Column != nil { + node, ok := nod.Column.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Column = node.(*ColumnName) + nod.Column = node.(*ColumnName) } - if ss.Pattern != nil { - node, ok := ss.Pattern.Accept(v) + if nod.Pattern != nil { + node, ok := nod.Pattern.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Pattern = node.(*PatternLikeExpr) + nod.Pattern = node.(*PatternLikeExpr) } - if ss.Where != nil { - node, ok := ss.Where.Accept(v) + if nod.Where != nil { + node, ok := nod.Where.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - return v.Leave(ss) + return v.Leave(nod) } // BeginStmt is a statement to start a new transaction. @@ -215,11 +225,13 @@ type BeginStmt struct { } // Accept implements Node Accept interface. -func (bs *BeginStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(bs); skipChildren { - return bs, ok +func (nod *BeginStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(bs) + nod = newNod.(*BeginStmt) + return v.Leave(nod) } // CommitStmt is a statement to commit the current transaction. @@ -229,11 +241,13 @@ type CommitStmt struct { } // Accept implements Node Accept interface. -func (cs *CommitStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (nod *CommitStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cs) + nod = newNod.(*CommitStmt) + return v.Leave(nod) } // RollbackStmt is a statement to roll back the current transaction. @@ -243,11 +257,13 @@ type RollbackStmt struct { } // Accept implements Node Accept interface. -func (rs *RollbackStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rs); skipChildren { - return rs, ok +func (nod *RollbackStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(rs) + nod = newNod.(*RollbackStmt) + return v.Leave(nod) } // UseStmt is a statement to use the DBName database as the current database. @@ -259,11 +275,13 @@ type UseStmt struct { } // Accept implements Node Accept interface. -func (us *UseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(us); skipChildren { - return us, ok +func (nod *UseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(us) + nod = newNod.(*UseStmt) + return v.Leave(nod) } // VariableAssignment is a variable assignment struct. @@ -276,16 +294,18 @@ type VariableAssignment struct { } // Accept implements Node interface. -func (va *VariableAssignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (nod *VariableAssignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Value.Accept(v) + nod = newNod.(*VariableAssignment) + node, ok := nod.Value.Accept(v) if !ok { - return va, false + return nod, false } - va.Value = node.(ExprNode) - return v.Leave(va) + nod.Value = node.(ExprNode) + return v.Leave(nod) } // SetStmt is the statement to set variables. @@ -296,18 +316,20 @@ type SetStmt struct { } // Accept implements Node Accept interface. -func (set *SetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (nod *SetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range set.Variables { + nod = newNod.(*SetStmt) + for i, val := range nod.Variables { node, ok := val.Accept(v) if !ok { - return set, false + return nod, false } - set.Variables[i] = node.(*VariableAssignment) + nod.Variables[i] = node.(*VariableAssignment) } - return v.Leave(set) + return v.Leave(nod) } // SetCharsetStmt is a statement to assign values to character and collation variables. @@ -320,11 +342,13 @@ type SetCharsetStmt struct { } // Accept implements Node Accept interface. -func (set *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (nod *SetCharsetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + nod = newNod.(*SetCharsetStmt) + return v.Leave(nod) } // SetPwdStmt is a statement to assign a password to user account. @@ -337,11 +361,13 @@ type SetPwdStmt struct { } // Accept implements Node Accept interface. -func (set *SetPwdStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (nod *SetPwdStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + nod = newNod.(*SetPwdStmt) + return v.Leave(nod) } // UserSpec is used for parsing create user statement. @@ -367,16 +393,18 @@ type DoStmt struct { } // Accept implements Node Accept interface. -func (do *DoStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(do); skipChildren { - return do, ok +func (nod *DoStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range do.Exprs { + nod = newNod.(*DoStmt) + for i, val := range nod.Exprs { node, ok := val.Accept(v) if !ok { - return do, false + return nod, false } - do.Exprs[i] = node.(ExprNode) + nod.Exprs[i] = node.(ExprNode) } - return v.Leave(do) + return v.Leave(nod) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 24cde087e6..97790fd0d5 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -394,7 +394,7 @@ import ( FieldAsName "Field alias name" FieldAsNameOpt "Field alias name opt" FieldList "field expression list" - FromClause "From clause" + TableRefsClause "Table references clause" Function "function expr" FunctionCallAgg "Function call on aggregate data" FunctionCallConflict "Function call with reserved keyword as function name" @@ -404,7 +404,6 @@ import ( FuncDatetimePrec "Function datetime precision" GlobalScope "The scope of variable" GroupByClause "GROUP BY clause" - GroupByList "GROUP BY list" HashString "Hashed string" HavingClause "HAVING clause" IfExists "If Exists" @@ -437,9 +436,9 @@ import ( OptInteger "Optional Integer keyword" Order "ORDER BY clause optional collation specification" OrderBy "ORDER BY clause" - OrderByItem "ORDER BY list item" + ByItem "BY item" OrderByOptional "Optional ORDER BY clause optional" - OrderByList "ORDER BY list" + ByList "BY list" OuterOpt "optional OUTER clause" QuickOptional "QUICK or empty" PasswordOpt "Password option" @@ -452,7 +451,6 @@ import ( ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" RollbackStmt "ROLLBACK statement" - SelectBasic "Basic SELECT statement without parentheses and union" SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," SelectStmt "SELECT statement" SelectStmtCalcFoundRows "SELECT statement optional SQL_CALC_FOUND_ROWS" @@ -491,9 +489,8 @@ import ( TrimDirection "Trim string direction" TruncateTableStmt "TRANSACTION TABLE statement" UnionOpt "Union Option(empty/ALL/DISTINCT)" - UnionClause "Union select" + UnionStmt "Union select state ment" UnionClauseList "Union select clause list" - UnionClauseP "Union (select)" UnionClausePList "Union (select) clause list" UpdateStmt "UPDATE statement" Username "Username" @@ -1141,12 +1138,13 @@ DeleteFromStmt: "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableName WhereClauseOptional OrderByOptional LimitClause { // Single Table + join := &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil} x := &ast.DeleteStmt{ - TableRefs: &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil}, + TableRefs: &ast.TableRefsClause{TableRefs: join}, LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - Order: $8.([]*ast.OrderByItem), + Order: $8.([]*ast.ByItem), } if $7 != nil { x.Where = $7.(ast.ExprNode) @@ -1171,7 +1169,7 @@ DeleteFromStmt: MultiTable: true, BeforeFrom: true, Tables: $5.([]*ast.TableName), - TableRefs: $7.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $7.(*ast.Join)}, } if $8 != nil { x.Where = $8.(ast.ExprNode) @@ -1190,7 +1188,7 @@ DeleteFromStmt: Ignore: $4.(bool), MultiTable: true, Tables: $6.([]*ast.TableName), - TableRefs: $8.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $8.(*ast.Join)}, } if $9 != nil { x.Where = $9.(ast.ExprNode) @@ -1493,13 +1491,13 @@ Field: } | Identifier '.' '*' { - tn := &ast.TableName{Name:model.NewCIStr($1.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} + wildCard := &ast.WildCardField{Table: model.NewCIStr($1.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Identifier '.' Identifier '.' '*' { - tn := &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} + wildCard := &ast.WildCardField{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Expression FieldAsNameOpt { @@ -1545,19 +1543,9 @@ FieldList: } GroupByClause: - "GROUP" "BY" GroupByList + "GROUP" "BY" ByList { - $$ = $3.([]ast.ExprNode) - } - -GroupByList: - Expression - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| GroupByList ',' Expression - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) + $$ = &ast.GroupByClause{Items: $3.([]*ast.ByItem)} } HavingClause: @@ -1566,7 +1554,7 @@ HavingClause: } | "HAVING" Expression { - $$ = $2.(ast.ExprNode) + $$ = &ast.HavingClause{Expr: $2.(ast.ExprNode)} } IfExists: @@ -1643,7 +1631,9 @@ InsertIntoStmt: { x := $6.(*ast.InsertStmt) x.Priority = $2.(int) - x.Table = $5.(*ast.TableName) + // Wraps many layers here so that it can be processed the same way as select statement. + ts := &ast.TableSource{Source: $5.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} if $7 != nil { x.OnDuplicate = $7.([]*ast.Assignment) } @@ -1804,25 +1794,25 @@ Operand: } OrderBy: - "ORDER" "BY" OrderByList + "ORDER" "BY" ByList { - $$ = $3.([]*ast.OrderByItem) + $$ = &ast.OrderByClause{Items: $3.([]*ast.ByItem)} } -OrderByList: - OrderByItem +ByList: + ByItem { - $$ = []*ast.OrderByItem{$1.(*ast.OrderByItem)} + $$ = []*ast.ByItem{$1.(*ast.ByItem)} } -| OrderByList ',' OrderByItem +| ByList ',' ByItem { - $$ = append($1.([]*ast.OrderByItem), $3.(*ast.OrderByItem)) + $$ = append($1.([]*ast.ByItem), $3.(*ast.ByItem)) } -OrderByItem: +ByItem: Expression Order { - $$ = &ast.OrderByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} + $$ = &ast.ByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} } Order: @@ -2556,40 +2546,23 @@ RollbackStmt: } SelectStmt: - SelectBasic -| SelectBasic UnionClauseList - { - st := $1.(*ast.SelectStmt) - st.Unions = $2.([]*ast.UnionClause) - $$ = st - } -| SubSelect UnionClausePList OrderByOptional SelectStmtLimit - { - st := $1.(*ast.SubqueryExpr).Query - st.Unions = $2.([]*ast.UnionClause) - st.UnionOrderBy = $3.([]*ast.OrderByItem) - st.UnionLimit = $4.(*ast.Limit) - $$ = st - } - -SelectBasic: "SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt { $$ = &ast.SelectStmt { Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), + Fields: $3.(*ast.FieldList), From: nil, LockTp: $6.(ast.SelectLockType), } } | "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - FromClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional + TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional SelectStmtLimit SelectLockOpt { st := &ast.SelectStmt{ Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), - From: $5.(*ast.Join), + Fields: $3.(*ast.FieldList), + From: $5.(*ast.TableRefsClause), LockTp: $11.(ast.SelectLockType), } @@ -2598,15 +2571,15 @@ SelectBasic: } if $7 != nil { - st.GroupBy = $7.([]ast.ExprNode) + st.GroupBy = $7.(*ast.GroupByClause) } if $8 != nil { - st.Having = $8.(ast.ExprNode) + st.Having = $8.(*ast.HavingClause) } if $9 != nil { - st.OrderBy = $9.([]*ast.OrderByItem) + st.OrderBy = $9.(*ast.OrderByClause) } if $10 != nil { @@ -2621,17 +2594,17 @@ FromDual: | "FROM" "DUAL" -FromClause: +TableRefsClause: TableRefs { - $$ = $1 + $$ = &ast.TableRefsClause{TableRefs: $1.(*ast.Join)} } TableRefs: EscapedTableRef { if j, ok := $1.(*ast.Join); ok { - // if $1 is JoinRset, use it directly + // if $1 is Join, use it directly $$ = j } else { $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: nil} @@ -2708,11 +2681,13 @@ JoinTable: } | TableRef CrossOpt TableRef "ON" Expression { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: $5.(ast.ExprNode)} + on := &ast.OnCondition{Expr: $5.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on} } | TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression - { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: $7.(ast.ExprNode)} + { + on := &ast.OnCondition{Expr: $7.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on} } /* Support Using */ @@ -2799,7 +2774,7 @@ SelectStmtCalcFoundRows: SelectStmtFieldList: FieldList { - $$ = $1 + $$ = &ast.FieldList{Fields: $1.([]*ast.SelectField)} } SelectStmtGroup: @@ -2836,36 +2811,61 @@ SelectLockOpt: } // See: https://dev.mysql.com/doc/refman/5.7/en/union.html -UnionClause: - "UNION" UnionOpt SelectBasic +UnionStmt: + UnionClauseList { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SelectStmt)} + $$ = $1.(*ast.UnionStmt) } - -UnionClauseP: - "UNION" UnionOpt SubSelect +| UnionClausePList OrderByOptional SelectStmtLimit { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SubqueryExpr).Query} + union := $1.(*ast.UnionStmt) + if $2 != nil { + // push union order by into select statements. + orderBy := $2.(*ast.OrderByClause) + cloner := &ast.Cloner{} + for _, s := range union.Selects { + node, _ := orderBy.Accept(cloner) + s.OrderBy = node.(*ast.OrderByClause) + } + } + if $3 != nil { + union.Limit = $3.(*ast.Limit) + } + $$ = union } UnionClauseList: - UnionClause + SelectStmt "UNION" UnionOpt SelectStmt { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} + selects := []*ast.SelectStmt{$1.(*ast.SelectStmt), $4.(*ast.SelectStmt)} + $$ = &ast.UnionStmt{ + Distinct: $3.(bool), + Selects: selects, + } } -| UnionClauseList UnionClause +| UnionClauseList "UNION" UnionOpt SelectStmt { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union } UnionClausePList: - UnionClauseP + '(' SelectStmt ')' "UNION" UnionOpt '(' SelectStmt ')' { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} + selects := []*ast.SelectStmt{$2.(*ast.SelectStmt), $7.(*ast.SelectStmt)} + $$ = &ast.UnionStmt{ + Distinct: $5.(bool), + Selects: selects, + } } -| UnionClausePList UnionClauseP +| UnionClausePList "UNION" UnionOpt '(' SelectStmt ')' { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) + $$ = union } UnionOpt: @@ -3071,7 +3071,7 @@ ShowStmt: Where: $3.(ast.ExprNode), } } -| "SHOW" "CREATE" "TABLE" TableName +| "SHOW" "CREATE" "TABLE" TableName { $$ = &ast.ShowStmt{ Tp: ast.ShowCreateTable, @@ -3794,17 +3794,17 @@ UpdateStmt: "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { // Single-table syntax - r := &ast.Join{Left: $4.(ast.ResultSetNode), Right: nil} + join := &ast.Join{Left: &ast.TableSource{Source:$4.(ast.ResultSetNode)}, Right: nil} st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: r, + TableRefs: &ast.TableRefsClause{TableRefs: join}, List: $6.([]*ast.Assignment), } if $7 != nil { st.Where = $7.(ast.ExprNode) } if $8 != nil { - st.Order = $8.([]*ast.OrderByItem) + st.Order = $8.([]*ast.ByItem) } if $9 != nil { st.Limit = $9.(*ast.Limit) @@ -3819,7 +3819,7 @@ UpdateStmt: // Multiple-table syntax st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: $4.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, List: $6.([]*ast.Assignment), MultipleTable: true, } diff --git a/ast/parser/parser_test.go b/ast/parser/parser_test.go index 6b32ea65fc..1e2798c28c 100644 --- a/ast/parser/parser_test.go +++ b/ast/parser/parser_test.go @@ -56,8 +56,8 @@ func (s *testParserSuite) TestSimple(c *C) { st := l.Stmts()[0] ss, ok := st.(*ast.SelectStmt) c.Assert(ok, IsTrue) - c.Assert(len(ss.Fields), Equals, 1) - cv, ok := ss.Fields[0].Expr.(*ast.FuncCastExpr) + c.Assert(len(ss.Fields.Fields), Equals, 1) + cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr) c.Assert(ok, IsTrue) c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction) diff --git a/ast/parser/yy_parser.go b/ast/parser/yy_parser.go new file mode 100644 index 0000000000..dfc3264796 --- /dev/null +++ b/ast/parser/yy_parser.go @@ -0,0 +1,19 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser + +// YYParse is an wrapper of `yyParse` to make it exported. +func YYParse(yylex yyLexer) int { + return yyParse(yylex) +} diff --git a/optimizer/infobinder.go b/optimizer/infobinder.go new file mode 100644 index 0000000000..0ad59bceb6 --- /dev/null +++ b/optimizer/infobinder.go @@ -0,0 +1,441 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/model" +) + +// InfoBinder binds schema information for table name and column name and set result fields +// for ResetSetNode. +// We need to know which table a table name refers to, which column a column name refers to. +// +// In general, a reference can only refer to information that are available for it. +// So children elements are visited in the order that previous elements make information +// available for following elements. +// +// During visiting, information are collected and stored in binderContext. +// When we enter a sub query, a new binderContext is pushed to the contextStack, so sub query +// information can overwrite outer query information. When we look up for a column reference, +// we look up from top to bottom in the contextStack. +type InfoBinder struct { + Info infoschema.InfoSchema + DefaultSchema model.CIStr + Err error + + contextStack []*binderContext +} + +// binderContext stores information that table name and column name +// can be bind to. +type binderContext struct { + /* For Select Statement. */ + // table map to lookup and check table name conflict. + tableMap map[string]int + // tableSources collected in from clause. + tables []*ast.TableSource + // result fields collected in select field list. + fieldList []*ast.ResultField + // result fields collected in group by clause. + groupBy []*ast.ResultField + + // The join node stack is used by on condition to find out + // available tables to reference. On condition can only + // refer to tables involved in current join. + joinNodeStack []*ast.Join + + // When visiting TableRefs, tables in this context are not available + // because it is being collected. + inTableRefs bool + // When visiting on conditon only tables in current join node are available. + inOnCondition bool + // When visiting field list, fieldList in this context are not available. + inFieldList bool + // When visiting group by, groupBy fields are not available. + inGroupBy bool + // When visiting having, only fieldList and groupBy fields are available. + inHaving bool +} + +// currentContext gets the current binder context. +func (sb *InfoBinder) currentContext() *binderContext { + stackLen := len(sb.contextStack) + if stackLen == 0 { + return nil + } + return sb.contextStack[stackLen-1] +} + +// pushContext is called when we enter a statement. +func (sb *InfoBinder) pushContext() { + sb.contextStack = append(sb.contextStack, &binderContext{ + tableMap: map[string]int{}, + }) +} + +// popContext is called when we leave a statement. +func (sb *InfoBinder) popContext() { + sb.contextStack = sb.contextStack[:len(sb.contextStack)-1] +} + +// pushJoin is called when we enter a join node. +func (sb *InfoBinder) pushJoin(j *ast.Join) { + ctx := sb.currentContext() + ctx.joinNodeStack = append(ctx.joinNodeStack, j) +} + +// popJoin is called when we leave a join node. +func (sb *InfoBinder) popJoin() { + ctx := sb.currentContext() + ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1] +} + +// Enter implements ast.Visitor interface. +func (sb *InfoBinder) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { + switch v := inNode.(type) { + case *ast.SelectStmt: + sb.pushContext() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = true + case *ast.Join: + sb.pushJoin(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = true + case *ast.FieldList: + sb.currentContext().inFieldList = true + case *ast.GroupByClause: + sb.currentContext().inGroupBy = true + case *ast.HavingClause: + sb.currentContext().inHaving = true + case *ast.InsertStmt: + sb.pushContext() + case *ast.DeleteStmt: + sb.pushContext() + case *ast.UpdateStmt: + sb.pushContext() + } + return inNode, false +} + +// Leave implements ast.Visitor interface. +func (sb *InfoBinder) Leave(inNode ast.Node) (node ast.Node, ok bool) { + switch v := inNode.(type) { + case *ast.TableName: + sb.handleTableName(v) + case *ast.ColumnName: + sb.handleColumnName(v) + case *ast.TableSource: + sb.handleTableSource(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = false + case *ast.Join: + sb.handleJoin(v) + sb.popJoin() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = false + case *ast.FieldList: + sb.handleFieldList(v) + sb.currentContext().inFieldList = false + case *ast.GroupByClause: + sb.currentContext().inGroupBy = false + case *ast.HavingClause: + sb.currentContext().inHaving = false + case *ast.SelectStmt: + v.SetResultFields(sb.currentContext().fieldList) + sb.popContext() + case *ast.InsertStmt: + sb.popContext() + case *ast.DeleteStmt: + sb.popContext() + case *ast.UpdateStmt: + sb.popContext() + } + return inNode, sb.Err == nil +} + +// handleTableName looks up and bind the schema information for table name +// and set result fields for table name. +func (sb *InfoBinder) handleTableName(tn *ast.TableName) { + if tn.Schema.L == "" { + tn.Schema = sb.DefaultSchema + } + table, err := sb.Info.TableByName(tn.Schema, tn.Name) + if err != nil { + sb.Err = err + return + } + tn.TableInfo = table.Meta() + dbInfo, _ := sb.Info.SchemaByName(tn.Schema) + tn.DBInfo = dbInfo + + rfs := make([]*ast.ResultField, len(tn.TableInfo.Columns)) + for i, v := range tn.TableInfo.Columns { + rfs[i] = &ast.ResultField{ + Column: v, + Table: tn.TableInfo, + DBName: tn.Schema, + } + } + tn.SetResultFields(rfs) + return +} + +// handleTableSources checks name duplication +// and puts the table source in current binderContext. +func (sb *InfoBinder) handleTableSource(ts *ast.TableSource) { + for _, v := range ts.GetResultFields() { + v.TableAsName = ts.AsName + } + var name string + if ts.AsName.L != "" { + name = ts.AsName.L + } else { + tableName := ts.Source.(*ast.TableName) + name = sb.tableUniqueName(tableName.Schema, tableName.Name) + } + ctx := sb.currentContext() + if _, ok := ctx.tableMap[name]; ok { + sb.Err = errors.Errorf("duplicated table/alias name %s", name) + return + } + ctx.tableMap[name] = len(ctx.tables) + ctx.tables = append(ctx.tables, ts) + return +} + +// handleJoin sets result fields for join. +func (sb *InfoBinder) handleJoin(j *ast.Join) { + if j.Right == nil { + j.SetResultFields(j.Left.GetResultFields()) + return + } + leftLen := len(j.Left.GetResultFields()) + rightLen := len(j.Right.GetResultFields()) + rfs := make([]*ast.ResultField, leftLen+rightLen) + copy(rfs, j.Left.GetResultFields()) + copy(rfs[leftLen:], j.Right.GetResultFields()) + j.SetResultFields(rfs) +} + +// handleColumnName looks up and binds schema information to +// the column name. +func (sb *InfoBinder) handleColumnName(cn *ast.ColumnName) { + ctx := sb.currentContext() + if ctx.inOnCondition { + // In on condition, only tables within current join is available. + sb.bindColumnNameInOnCondition(cn) + return + } + + // Try to bind the column name form top to bottom in the context stack. + for i := len(sb.contextStack) - 1; i >= 0; i-- { + if sb.bindColumnNameInContext(sb.contextStack[i], cn) { + // Column is already bound or encountered an error. + return + } + } + sb.Err = errors.Errorf("Unknown column %s", cn.Name.L) +} + +// bindColumnNameInContext looks up and binds schema information for a column with the ctx. +func (sb *InfoBinder) bindColumnNameInContext(ctx *binderContext, cn *ast.ColumnName) (done bool) { + if cn.Table.L == "" { + // If qualified table name is not specified in column name, the column name may be ambiguous, + // We need to iterate over all tables and + } + + if ctx.inTableRefs { + // In TableRefsClause, column reference only in join on condition which is handled before. + return false + } + if ctx.inFieldList { + // only bind column using tables. + return sb.bindColumnInTableSources(cn, ctx.tables) + } + if ctx.inGroupBy { + // field list first, then tables. + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + return sb.bindColumnInTableSources(cn, ctx.tables) + } + // column name in other places can be looked up in the same order. + if sb.bindColumnInResultFields(cn, ctx.groupBy) { + return true + } + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + + // tables is not available for having clause. + if !ctx.inHaving { + return sb.bindColumnInTableSources(cn, ctx.tables) + } + return false +} + +// bindColumnNameInOnCondition looks up for column name in current join, and +// binds the schema information. +func (sb *InfoBinder) bindColumnNameInOnCondition(cn *ast.ColumnName) { + ctx := sb.currentContext() + join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1] + tableSources := appendTableSources(nil, join) + if !sb.bindColumnInTableSources(cn, tableSources) { + sb.Err = errors.Errorf("unkown column name %s", cn.Name.O) + } +} + +func (sb *InfoBinder) bindColumnInTableSources(cn *ast.ColumnName, tableSources []*ast.TableSource) (done bool) { + var matchedResultField *ast.ResultField + if cn.Table.L != "" { + var matchedTable ast.ResultSetNode + for _, ts := range tableSources { + if cn.Table.L == ts.AsName.L { + // different table name. + matchedTable = ts + break + } + if tn, ok := ts.Source.(*ast.TableName); ok { + if cn.Table.L == tn.Name.L { + matchedTable = ts + } + } + } + if matchedTable != nil { + resultFields := matchedTable.GetResultFields() + for _, rf := range resultFields { + if rf.ColumnAsName.L == cn.Name.L || rf.Column.Name.L == cn.Name.L { + // bind column. + matchedResultField = rf + break + } + } + } + } else { + for _, ts := range tableSources { + rfs := ts.GetResultFields() + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return true + } + matchedResultField = rf + } + } + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +func (sb *InfoBinder) bindColumnInResultFields(cn *ast.ColumnName, rfs []*ast.ResultField) bool { + var matchedResultField *ast.ResultField + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return false + } + matchedResultField = rf + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +// handleFieldList expands wild card field and set fieldList in current context. +func (sb *InfoBinder) handleFieldList(fieldList *ast.FieldList) { + var resultFields []*ast.ResultField + for _, v := range fieldList.Fields { + resultFields = append(resultFields, sb.createResultFields(v)...) + } + sb.currentContext().fieldList = resultFields +} + +// createResultFields creates result field list for a single select field. +func (sb *InfoBinder) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) { + ctx := sb.currentContext() + if field.WildCard != nil { + if len(ctx.tables) == 0 { + sb.Err = errors.Errorf("No table used.") + return + } + if field.WildCard.Table.L == "" { + for _, v := range ctx.tables { + rfs = append(rfs, v.GetResultFields()...) + } + } else { + name := sb.tableUniqueName(field.WildCard.Schema, field.WildCard.Table) + tableIdx, ok := ctx.tableMap[name] + if !ok { + sb.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O) + } + rfs = ctx.tables[tableIdx].GetResultFields() + } + return + } + // The column is visited before so it must has been bound already. + rf := &ast.ResultField{ColumnAsName: field.AsName} + switch v := field.Expr.(type) { + case *ast.ColumnNameExpr: + rf.Column = v.Name.ColumnInfo + rf.Table = v.Name.TableInfo + rf.DBName = v.Name.Schema + default: + if field.AsName.L == "" { + rf.ColumnAsName.L = field.Expr.Text() + rf.ColumnAsName.O = rf.ColumnAsName.L + } + } + rfs = append(rfs, rf) + return +} + +func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) { + switch v := resultSetNode.(type) { + case *ast.TableSource: + out = append(in, v) + case *ast.Join: + out = appendTableSources(in, v.Left) + if v.Right != nil { + out = appendTableSources(out, v.Right) + } + } + return +} + +func (sb *InfoBinder) tableUniqueName(schema, table model.CIStr) string { + if schema.L != "" && schema.L != sb.DefaultSchema.L { + return schema.L + "." + table.L + } + return table.L +} diff --git a/optimizer/infobinder_test.go b/optimizer/infobinder_test.go new file mode 100644 index 0000000000..f2109e0c17 --- /dev/null +++ b/optimizer/infobinder_test.go @@ -0,0 +1,69 @@ +package optimizer_test + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ast/parser" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/optimizer" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/testkit" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testInfoBinderSuite{}) + +type testInfoBinderSuite struct { +} + +type binderVerifier struct { + c *C +} + +func (bv *binderVerifier) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (bv *binderVerifier) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ColumnName: + bv.c.Assert(v.ColumnInfo, NotNil) + case *ast.TableName: + bv.c.Assert(v.TableInfo, NotNil) + } + return in, true +} + +func (ts *testInfoBinderSuite) TestInfoBinder(c *C) { + store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory) + c.Assert(err, IsNil) + defer store.Close() + testKit := testkit.NewTestKit(c, store) + testKit.MustExec("use test") + testKit.MustExec("create table t (c1 int, c2 int)") + domain := sessionctx.GetDomain(testKit.Se.(context.Context)) + + src := "SELECT c1 from t" + l := parser.NewLexer(src) + c.Assert(parser.YYParse(l), Equals, 0) + stmts := l.Stmts() + c.Assert(len(stmts), Equals, 1) + v := &optimizer.InfoBinder{ + Info: domain.InfoSchema(), + DefaultSchema: model.NewCIStr("test"), + } + selectStmt := stmts[0].(*ast.SelectStmt) + selectStmt.Accept(v) + + verifier := &binderVerifier{ + c: c, + } + selectStmt.Accept(verifier) +} From 0147eedd2c8822a6cc65c266f2098dbceb0f92a5 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 27 Oct 2015 16:05:46 +0800 Subject: [PATCH 02/25] optimiser: implement evaluator. --- ast/ast.go | 4 + ast/base.go | 11 + ast/expressions.go | 2 - ast/functions.go | 29 +- ast/parser/parser.y | 28 +- optimizer/aggregator.go | 13 + optimizer/evaluator.go | 657 ++++++++++++++++++++++++++++++++++++++ optimizer/optimizer.go | 25 +- optimizer/typecomputer.go | 16 + optimizer/validator.go | 16 + 10 files changed, 781 insertions(+), 20 deletions(-) create mode 100644 optimizer/aggregator.go create mode 100644 optimizer/evaluator.go create mode 100644 optimizer/typecomputer.go create mode 100644 optimizer/validator.go diff --git a/ast/ast.go b/ast/ast.go index dabc4fa912..fc505c8362 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -44,6 +44,10 @@ type ExprNode interface { SetType(tp *types.FieldType) // GetType gets the evaluation type of the expression. GetType() *types.FieldType + // SetValue sets value to the expression. + SetValue(val interface{}) + // GetValue gets value of the expression. + GetValue() interface{} } // FuncNode represents function call expression node. diff --git a/ast/base.go b/ast/base.go index 50e9ef030c..2b25bb7f96 100644 --- a/ast/base.go +++ b/ast/base.go @@ -62,6 +62,7 @@ func (dn *dmlNode) dmlStatement() {} // Expression implementations should embed it in. type exprNode struct { node + val interface{} tp *types.FieldType } @@ -80,6 +81,16 @@ func (en *exprNode) GetType() *types.FieldType { return en.tp } +// SetValue implements Expression interface. +func (en *exprNode) SetValue(val interface{}) { + en.val = val +} + +// GetValue implements Expression interface. +func (en *exprNode) GetValue() interface{} { + return en.val +} + type funcNode struct { exprNode } diff --git a/ast/expressions.go b/ast/expressions.go index c99611ae13..f06109d312 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -47,8 +47,6 @@ var ( // ValueExpr is the simple value expression. type ValueExpr struct { exprNode - // Val is the literal value. - Val interface{} } // IsStatic implements ExprNode interface. diff --git a/ast/functions.go b/ast/functions.go index 4e052b3ccf..f3efdbae41 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -36,9 +36,6 @@ type FuncCallExpr struct { F string // Args is the function args. Args []ExprNode - // Distinct only affetcts sum, avg, count, group_concat, - // so we can ignore it in other functions - Distinct bool } // Accept implements Node interface. @@ -335,3 +332,29 @@ func (nod *FuncTrimExpr) IsStatic() bool { // TypeStar is a special type for "*". type TypeStar string + +type AggregateFuncExpr struct { + funcNode + // F is the function name. + F string + // Args is the function args. + Args []ExprNode + Distinct bool +} + +// Accept implements Node Accept interface. +func (nod *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*AggregateFuncExpr) + for i, val := range nod.Args { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Args[i] = node.(ExprNode) + } + return v.Leave(nod) +} \ No newline at end of file diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 97790fd0d5..9397163269 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1888,7 +1888,7 @@ FunctionNameConflict: FunctionCallConflict: FunctionNameConflict '(' ExpressionListOpt ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: false} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURRENT_USER" { @@ -1918,18 +1918,14 @@ DistinctOpt: } FunctionCallKeyword: - "AVG" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "CAST" '(' Expression "AS" CastType ')' + "CAST" '(' Expression "AS" CastType ')' { /* See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_cast */ $$ = &ast.FuncCastExpr{ Expr: $3.(ast.ExprNode), Tp: $5.(*types.FieldType), FunctionType: ast.CastFunction, - } + } } | "CASE" ExpressionOpt WhenClauseList ElseOpt "END" { @@ -2211,30 +2207,34 @@ TrimDirection: } FunctionCallAgg: - "COUNT" '(' DistinctOpt ExpressionList ')' + "AVG" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} + } +| "COUNT" '(' DistinctOpt ExpressionList ')' + { + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt '*' ')' { args := []ast.ExprNode{&ast.ValueExpr{Val: ast.TypeStar("*")} } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "MAX" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "MIN" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "SUM" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } FuncDatetimePrec: diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go new file mode 100644 index 0000000000..95b9c7cb80 --- /dev/null +++ b/optimizer/aggregator.go @@ -0,0 +1,13 @@ +package optimizer + +// Aggregator is the interface to +// compute aggregate function result. +type Aggregator interface { + // Input add a input value to aggregator. + // The input values are accumulated in the aggregator. + Input(in []interface{}) error + // Output use input values to compute the aggregated result. + Output() interface{} + // Clear clears the input values. + Clear() +} diff --git a/optimizer/evaluator.go b/optimizer/evaluator.go new file mode 100644 index 0000000000..52dab1d6f2 --- /dev/null +++ b/optimizer/evaluator.go @@ -0,0 +1,657 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "math" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/plan" + "github.com/pingcap/tidb/util/types" +) + +type Evaluator struct { + // columnMap is the map from ColumnName to the position of the rowStack. + // It is used to find the value of the column. + columnMap map[*ast.ColumnName]position + // rowStack is the current row values while scanning. + // It should be updated after scaned a new row. + rowStack []*plan.Row + + // the map from AggregateFuncExpr to aggregator index. + aggregateMap map[*ast.AggregateFuncExpr]int + + // aggregators for the current row + // only outer query aggregate functions are handled. + aggregators []Aggregator + + // when aggregation phase is done, the input is + aggregateDone bool + + err error +} + +type position struct { + stackOffset int + fieldList bool + columnOffset int +} + +func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + ok = true + case *ast.AggregateFuncExpr: + ok = e.handleAggregateFunc(v) + case *ast.BetweenExpr: + ok = e.handleBetween(v) + case *ast.BinaryOperationExpr: + ok = e.handleBinaryOperation(v) + } + out = in + return +} + +func (e *Evaluator) handleAggregateFunc(v *ast.AggregateFuncExpr) bool { + idx := e.aggregateMap[v] + aggr := e.aggregators[idx] + if e.aggregateDone { + v.SetValue(aggr.Output()) + return true + } + // TODO: currently only single argument aggregate functions are supported. + e.err = aggr.Input(v.Args[0].GetValue()) + return e.err == nil +} + +func checkAllOneColumn(exprs ...ast.ExprNode) bool { + for _, expr := range exprs { + switch v := expr.(type) { + case *ast.RowExpr: + return false + case *ast.SubqueryExpr: + if len(v.Query.Fields.Fields) != 1 { + return false + } + } + } + return true +} + +func (e *Evaluator) handleBetween(v *ast.BetweenExpr) bool { + if !checkAllOneColumn(v.Expr, v.Left, v.Right) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + + var l, r ast.ExprNode + op := opcode.AndAnd + + if v.Not { + // v < lv || v > rv + op = opcode.OrOr + l = &ast.BinaryOperationExpr{opcode.LT, v.Expr, v.Left} + r = &ast.BinaryOperationExpr{opcode.GT, v.Expr, v.Right} + } else { + // v >= lv && v <= rv + l = &ast.BinaryOperationExpr{opcode.GE, v.Expr, v.Left} + r = &ast.BinaryOperationExpr{opcode.LE, v.Expr, v.Right} + } + + ret := &ast.BinaryOperationExpr{op, l, r} + ret.Accept(e) + return e.err == nil +} + +const ( + zeroI64 int64 = 0 + oneI64 int64 = 1 +) + +func (e *Evaluator) handleBinaryOperation(o *ast.BinaryOperationExpr) bool { + // all operands must have same column. + if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { + return nil, false + } + + // row constructor only supports comparison operation. + switch o.Op { + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + default: + if !checkAllOneColumn(o.L) { + return nil, errors.Errorf("Operand should contain 1 column(s)") + } + } + + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + + switch o.Op { + case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: + return e.handleLogicOperation(o) + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + return e.handleComparisonOp(o) + case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: + // TODO: MySQL doesn't support and not, we should remove it later. + return e.handleBitOp(o) + case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: + return e.handleArithmeticOp(o) + default: + panic("should never happen") + } +} + +func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + var boolVal bool + switch o.Op { + case opcode.AndAnd: + boolVal = leftVal != zeroI64 && rightVal != zeroI64 + case opcode.OrOr: + boolVal = leftVal != zeroI64 || rightVal != zeroI64 + case opcode.LogicXor: + boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) + default: + panic("should never happen") + } + if boolVal { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + if types.IsNil(a) || types.IsNil(b) { + // for <=>, if a and b are both nil, return true. + // if a or b is nil, return false. + if o.Op == opcode.NullEQ { + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + } else { + o.SetValue(nil) + } + return true + } + + n, err := types.Compare(a, b) + if err != nil { + e.err = err + return false + } + + r, err := getCompResult(o.Op, n) + if err != nil { + e.err = err + return false + } + if r { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func getCompResult(op opcode.Op, value int) (bool, error) { + switch op { + case opcode.LT: + return value < 0, nil + case opcode.LE: + return value <= 0, nil + case opcode.GE: + return value >= 0, nil + case opcode.GT: + return value > 0, nil + case opcode.EQ: + return value == 0, nil + case opcode.NE: + return value != 0, nil + case opcode.NullEQ: + return value == 0, nil + default: + return false, errors.Errorf("invalid op %v in comparision operation", op) + } +} + +func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + + if types.IsNil(a) || types.IsNil(b) { + return nil, nil + } + + x, err := types.ToInt64(a) + if err != nil { + e.err = err + return false + } + + y, err := types.ToInt64(b) + if err != nil { + e.err = err + return false + } + + // use a int64 for bit operator, return uint64 + switch o.Op { + case opcode.And: + o.SetValue(uint64(x & y)) + case opcode.Or: + o.SetValue(uint64(x | y)) + case opcode.Xor: + o.SetValue(uint64(x ^ y)) + case opcode.RightShift: + o.SetValue(uint64(x) >> uint64(y)) + case opcode.LeftShift: + o.SetValue(uint64(x) << uint64(y)) + default: + e.err = errors.Errorf("invalid op %v in bit operation", o.Op) + return false + } + return true +} + +func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { + a, err := coerceArithmetic(o.L.GetValue()) + if err != nil { + e.err = err + return false + } + + b, err := coerceArithmetic(o.R.GetValue()) + if err != nil { + e.err = err + return false + } + a, b = types.Coerce(a, b) + + if a == nil || b == nil { + // TODO: for <=>, if a and b are both nil, return true + o.SetValue(nil) + return true + } + + // TODO: support logic division DIV + var result interface{} + switch o.Op { + case opcode.Plus: + result, e.err = computePlus(a, b) + case opcode.Minus: + result, e.err = computeMinus(a, b) + case opcode.Mul: + result, e.err = computeMul(a, b) + case opcode.Div: + result, e.err = computeDiv(a, b) + case opcode.Mod: + result, e.err = computeMod(a, b) + case opcode.IntDiv: + result, e.err = computeIntDiv(a, b) + default: + e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) + return false + } + o.SetValue(result) + return e.err == nil +} + +func computePlus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.AddInt64(x, y) + case uint64: + return types.AddInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.AddInteger(x, y) + case uint64: + return types.AddUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x + y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Add(y), nil + } + } + return types.InvOp2(a, b, opcode.Plus) +} + +func computeMinus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.SubInt64(x, y) + case uint64: + return types.SubIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + return types.SubUintWithInt(x, y) + case uint64: + return types.SubUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x - y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Sub(y), nil + } + } + + return types.InvOp2(a, b, opcode.Minus) +} + +func computeMul(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.MulInt64(x, y) + case uint64: + return types.MulInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.MulInteger(x, y) + case uint64: + return types.MulUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x * y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Mul(y), nil + } + } + + return types.InvOp2(a, b, opcode.Mul) +} + +func computeDiv(a, b interface{}) (interface{}, error) { + // MySQL support integer divison Div and division operator / + // we use opcode.Div for division operator and will use another for integer division later. + // for division operator, we will use float64 for calculation. + switch x := a.(type) { + case float64: + y, err := types.ToFloat64(b) + if err != nil { + return nil, err + } + + if y == 0 { + return nil, nil + } + + return x / y, nil + default: + // the scale of the result is the scale of the first operand plus + // the value of the div_precision_increment system variable (which is 4 by default) + // we will use 4 here + + xa, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + xb, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + if f, _ := xb.Float64(); f == 0 { + // division by zero return null + return nil, nil + } + + return xa.Div(xb), nil + } +} + +func computeMod(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return x % y, nil + case uint64: + if y == 0 { + return nil, nil + } else if x < 0 { + // first is int64, return int64. + return -int64(uint64(-x) % y), nil + } + return int64(uint64(x) % y), nil + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } else if y < 0 { + // first is uint64, return uint64. + return uint64(x % uint64(-y)), nil + } + return x % uint64(y), nil + case uint64: + if y == 0 { + return nil, nil + } + return x % y, nil + } + case float64: + switch y := b.(type) { + case float64: + if y == 0 { + return nil, nil + } + return math.Mod(x, y), nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + xf, _ := x.Float64() + yf, _ := y.Float64() + if yf == 0 { + return nil, nil + } + return math.Mod(xf, yf), nil + } + } + + return types.InvOp2(a, b, opcode.Mod) +} + +func computeIntDiv(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivInt64(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return types.DivIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivUintWithInt(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return x / y, nil + } + } + + // if any is none integer, use decimal to calculate + x, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + y, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + + if f, _ := y.Float64(); f == 0 { + return nil, nil + } + + return x.Div(y).IntPart(), nil +} + +func coerceArithmetic(a interface{}) (interface{}, error) { + switch x := a.(type) { + case string: + // MySQL will convert string to float for arithmetic operation + f, err := types.StrToFloat(x) + if err != nil { + return nil, err + } + return f, err + case mysql.Time: + // if time has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case mysql.Duration: + // if duration has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case []byte: + // []byte is the same as string, converted to float for arithmetic operator. + f, err := types.StrToFloat(string(x)) + if err != nil { + return nil, err + } + return f, err + case mysql.Hex: + return x.ToNumber(), nil + case mysql.Bit: + return x.ToNumber(), nil + case mysql.Enum: + return x.ToNumber(), nil + case mysql.Set: + return x.ToNumber(), nil + default: + return x, nil + } +} + +func columnCount(e ast.ExprNode) (int, error) { + switch x := e.(type) { + case *ast.RowExpr: + n := len(x.Values) + if n <= 1 { + return 0, errors.Errorf("Operand should contain >= 2 columns for Row") + } + return n, nil + case *ast.SubqueryExpr: + return len(x.Query.Fields.Fields), nil + default: + return 1, nil + } +} + +func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error { + l, err := columnCount(e) + if err != nil { + return errors.Trace(err) + } + var n int + for _, arg := range args { + n, err = columnCount(arg) + if err != nil { + return errors.Trace(err) + } + + if n != l { + return errors.Errorf("Operand should contain %d column(s)", l) + } + } + return nil +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 36ccd06c5f..aad7c6d732 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -16,17 +16,40 @@ package optimizer import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/stmt" + "github.com/juju/errors" ) // Compile compiles a ast.Node into a executable statement. func Compile(node ast.Node) (stmt.Statement, error) { + validator := &validator{} + if _, ok := node.Accept(validator); !ok { + return nil, errors.Trace(validator.err) + } + + binder := &InfoBinder{} + if _, ok := node.Accept(validator); !ok { + return nil, errors.Trace(binder.Err) + } + + tpComputer := &typeComputer{} + if _, ok := node.Accept(typeComputer{}); !ok { + return nil, errors.Trace(tpComputer.err) + } + switch v := node.(type) { + case *ast.SelectStmt: + case *ast.SetStmt: return compileSet(v) } return nil, nil } -func compileSet(aset *ast.SetStmt) (stmt.Statement, error) { +func compileSelect(s *ast.SelectStmt) (stmt.Statement, error) { + + return nil, nil +} + +func compileSet(s *ast.SetStmt) (stmt.Statement, error) { return nil, nil } diff --git a/optimizer/typecomputer.go b/optimizer/typecomputer.go new file mode 100644 index 0000000000..3da52497bc --- /dev/null +++ b/optimizer/typecomputer.go @@ -0,0 +1,16 @@ +package optimizer +import "github.com/pingcap/tidb/ast" + +// typeComputer is an ast Visitor that +// Compute types for ast.ExprNode. +type typeComputer struct { + err error +} + +func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) { + return +} \ No newline at end of file diff --git a/optimizer/validator.go b/optimizer/validator.go new file mode 100644 index 0000000000..cd69ebd47a --- /dev/null +++ b/optimizer/validator.go @@ -0,0 +1,16 @@ +package optimizer +import "github.com/pingcap/tidb/ast" + +// validator is an ast.Visitor that validates +// ast parsed from parser. +type validator struct { + err error +} + +func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { + return +} \ No newline at end of file From f3e45697def4f5e9428ba440f0d5ff806f1398a3 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 28 Oct 2015 19:42:08 +0800 Subject: [PATCH 03/25] optimiser: convert to old statement and expression. --- ast/base.go | 11 +- ast/cloner_test.go | 6 +- ast/dml.go | 3 +- ast/expressions.go | 40 +- ast/functions.go | 11 +- ast/parser/parser.y | 22 +- optimizer/aggregator.go | 15 +- optimizer/convert_expr.go | 402 ++++++++++++++++++++ optimizer/convert_stmt.go | 562 +++++++++++++++++++++++++++ optimizer/evaluator.go | 718 +++++++++-------------------------- optimizer/evaluator_binop.go | 527 +++++++++++++++++++++++++ optimizer/infobinder_test.go | 13 + optimizer/optimizer.go | 65 +++- optimizer/typecomputer.go | 20 +- optimizer/validator.go | 20 +- 15 files changed, 1865 insertions(+), 570 deletions(-) create mode 100644 optimizer/convert_expr.go create mode 100644 optimizer/convert_stmt.go create mode 100644 optimizer/evaluator_binop.go diff --git a/ast/base.go b/ast/base.go index 2b25bb7f96..877fa80b51 100644 --- a/ast/base.go +++ b/ast/base.go @@ -62,8 +62,7 @@ func (dn *dmlNode) dmlStatement() {} // Expression implementations should embed it in. type exprNode struct { node - val interface{} - tp *types.FieldType + types.DataItem } // IsStatic implements Expression interface. @@ -73,22 +72,22 @@ func (en *exprNode) IsStatic() bool { // SetType implements Expression interface. func (en *exprNode) SetType(tp *types.FieldType) { - en.tp = tp + en.Type = tp } // GetType implements Expression interface. func (en *exprNode) GetType() *types.FieldType { - return en.tp + return en.Type } // SetValue implements Expression interface. func (en *exprNode) SetValue(val interface{}) { - en.val = val + en.Data = val } // GetValue implements Expression interface. func (en *exprNode) GetValue() interface{} { - return en.val + return en.Data } type funcNode struct { diff --git a/ast/cloner_test.go b/ast/cloner_test.go index 186ee824a7..f6f567205b 100644 --- a/ast/cloner_test.go +++ b/ast/cloner_test.go @@ -21,7 +21,7 @@ func (ts *testClonerSuite) TestCloner(c *C) { a := &UnaryOperationExpr{ Op: opcode.Not, - V: &UnaryOperationExpr{V: &ValueExpr{Val: true}}, + V: &UnaryOperationExpr{V: NewValueExpr(true)}, } b, ok := a.Accept(cloner) @@ -35,6 +35,6 @@ func (ts *testClonerSuite) TestCloner(c *C) { a3 := a2.(*ValueExpr) b3 := b2.(*ValueExpr) c.Assert(a3, Not(Equals), b3) - c.Assert(a3.Val, Equals, true) - c.Assert(b3.Val, Equals, true) + c.Assert(a3.GetValue(), Equals, true) + c.Assert(b3.GetValue(), Equals, true) } diff --git a/ast/dml.go b/ast/dml.go index f7f11421bd..d46cd86394 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -22,6 +22,7 @@ var ( _ DMLNode = &DeleteStmt{} _ DMLNode = &UpdateStmt{} _ DMLNode = &SelectStmt{} + _ DMLNode = &UnionStmt{} _ Node = &Join{} _ Node = &TableName{} _ Node = &TableSource{} @@ -114,7 +115,7 @@ type TableSource struct { node // Source is the source of the data, can be a TableName, - // a SubQuery, or a JoinNode. + // a SelectStmt, a UnionStmt, or a JoinNode. Source ResultSetNode // AsName is the as name of the table source. diff --git a/ast/expressions.go b/ast/expressions.go index f06109d312..43fc176498 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -15,7 +15,9 @@ package ast import ( "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" ) var ( @@ -26,6 +28,7 @@ var ( _ ExprNode = &CaseExpr{} _ ExprNode = &SubqueryExpr{} _ ExprNode = &CompareSubqueryExpr{} + _ Node = &ColumnName{} _ ExprNode = &ColumnNameExpr{} _ ExprNode = &DefaultExpr{} _ ExprNode = &IdentifierExpr{} @@ -49,6 +52,39 @@ type ValueExpr struct { exprNode } +// NewValueExpr creates a ValueExpr with value, and sets default field type. +func NewValueExpr(value interface{}) *ValueExpr { + ve := &ValueExpr{} + ve.Data = value + // TODO: make it more precise. + switch value.(type) { + case bool, int64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + case uint64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + ve.Type.Flag |= mysql.UnsignedFlag + case string: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = mysql.DefaultCharset + ve.Type.Collate = mysql.DefaultCollationName + case float64: + ve.Type = types.NewFieldType(mysql.TypeDouble) + case []byte: + ve.Type = types.NewFieldType(mysql.TypeBlob) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + case mysql.Bit: + ve.Type = types.NewFieldType(mysql.TypeBit) + case mysql.Hex: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + default: + panic("illegal literal value type") + } + return ve +} + // IsStatic implements ExprNode interface. func (nod *ValueExpr) IsStatic() bool { return true @@ -247,7 +283,7 @@ func (nod *CaseExpr) IsStatic() bool { type SubqueryExpr struct { exprNode // Query is the query SelectNode. - Query *SelectStmt + Query ResultSetNode } // Accept implements Node Accept interface. @@ -261,7 +297,7 @@ func (nod *SubqueryExpr) Accept(v Visitor) (Node, bool) { if !ok { return nod, false } - nod.Query = node.(*SelectStmt) + nod.Query = node.(ResultSetNode) return v.Leave(nod) } diff --git a/ast/functions.go b/ast/functions.go index f3efdbae41..97b9cd1e3c 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -26,7 +26,9 @@ var ( _ FuncNode = &FuncConvertExpr{} _ FuncNode = &FuncCastExpr{} _ FuncNode = &FuncSubstringExpr{} + _ FuncNode = &FuncLocateExpr{} _ FuncNode = &FuncTrimExpr{} + _ FuncNode = &AggregateFuncExpr{} ) // FuncCallExpr is for function expression. @@ -330,15 +332,16 @@ func (nod *FuncTrimExpr) IsStatic() bool { return nod.Str.IsStatic() && nod.RemStr.IsStatic() } -// TypeStar is a special type for "*". -type TypeStar string - +// AggregateFuncExpr represents aggregate function expression. type AggregateFuncExpr struct { funcNode // F is the function name. F string // Args is the function args. Args []ExprNode + // If distinct is true, the function only aggregate distinct values. + // For example, column c1 values are "1", "2", "2", "sum(c1)" is "5", + // but "sum(distinct c1)" is "3". Distinct bool } @@ -357,4 +360,4 @@ func (nod *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { nod.Args[i] = node.(ExprNode) } return v.Leave(nod) -} \ No newline at end of file +} diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 9397163269..af69b11b61 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -944,15 +944,15 @@ NowSym: SignedLiteral: Literal { - $$ = ast.ValueExpr{Val: $1} + $$ = ast.NewValueExpr($1) } | '+' NumLiteral { - $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: &ast.ValueExpr{Val: $2}} + $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: ast.NewValueExpr($2)} } | '-' NumLiteral { - $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: &ast.ValueExpr{Val: $2}} + $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr($2)} } // TODO: support decimal literal @@ -1752,7 +1752,7 @@ Literal: Operand: Literal { - $$ = &ast.ValueExpr{Val: $1} + $$ = ast.NewValueExpr($1) } | ColumnName { @@ -2217,7 +2217,7 @@ FunctionCallAgg: } | "COUNT" '(' DistinctOpt '*' ')' { - args := []ast.ExprNode{&ast.ValueExpr{Val: ast.TypeStar("*")} } + args := []ast.ExprNode{ast.NewValueExpr("*")} $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' @@ -2649,6 +2649,10 @@ TableFactor: { $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } +| '(' UnionStmt ')' TableAsName + { + $$ = &ast.TableSource{Source: $2.(*ast.UnionStmt), AsName: $4.(model.CIStr)} + } | '(' TableRefs ')' { $$ = $2 @@ -2794,6 +2798,14 @@ SubSelect: s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) $$ = &ast.SubqueryExpr{Query: s} } +| '(' UnionStmt ')' + { + s := $2.(*ast.UnionStmt) + src := yylex.(*lexer).src + // See the implemention of yyParse function + s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) + $$ = &ast.SubqueryExpr{Query: s} + } // See: https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html SelectLockOpt: diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go index 95b9c7cb80..29b2fa3ea3 100644 --- a/optimizer/aggregator.go +++ b/optimizer/aggregator.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer // Aggregator is the interface to @@ -5,7 +18,7 @@ package optimizer type Aggregator interface { // Input add a input value to aggregator. // The input values are accumulated in the aggregator. - Input(in []interface{}) error + Input(in ...interface{}) error // Output use input values to compute the aggregated result. Output() interface{} // Clear clears the input values. diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go new file mode 100644 index 0000000000..4fe74ebf34 --- /dev/null +++ b/optimizer/convert_expr.go @@ -0,0 +1,402 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/subquery" + "github.com/pingcap/tidb/model" + "strings" +) + +func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression.Expression, error) { + expr.Accept(converter) + if converter.err != nil { + return nil, errors.Trace(converter.err) + } + return converter.exprMap[expr], nil +} + +// expressionConverter converts ast expression to +// old expression for transition state. +type expressionConverter struct { + exprMap map[ast.Node]expression.Expression + err error +} + +func newExpressionConverter() *expressionConverter { + return &expressionConverter{ + exprMap: map[ast.Node]expression.Expression{}, + } +} + +// Enter implements ast.Visitor interface. +func (c *expressionConverter) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +// Leave implements ast.Visitor interface. +func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + c.value(v) + case *ast.BetweenExpr: + c.between(v) + case *ast.BinaryOperationExpr: + c.binaryOperation(v) + case *ast.WhenClause: + c.whenClause(v) + case *ast.CaseExpr: + c.caseExpr(v) + case *ast.SubqueryExpr: + c.subquery(v) + case *ast.CompareSubqueryExpr: + c.compareSubquery(v) + case *ast.ColumnNameExpr: + c.columnNameExpr(v) + case *ast.DefaultExpr: + c.defaultExpr(v) + case *ast.IdentifierExpr: + c.identifier(v) + case *ast.ExistsSubqueryExpr: + c.existsSubquery(v) + case *ast.PatternInExpr: + c.patternIn(v) + case *ast.IsNullExpr: + c.isNull(v) + case *ast.IsTruthExpr: + c.isTruth(v) + case *ast.PatternLikeExpr: + c.patternLike(v) + case *ast.ParamMarkerExpr: + c.paramMarker(v) + case *ast.ParenthesesExpr: + c.parentheses(v) + case *ast.PositionExpr: + c.position(v) + case *ast.PatternRegexpExpr: + c.patternRegexp(v) + case *ast.RowExpr: + c.row(v) + case *ast.UnaryOperationExpr: + c.unaryOperation(v) + case *ast.ValuesExpr: + c.values(v) + case *ast.VariableExpr: + c.variable(v) + case *ast.FuncCallExpr: + c.funcCall(v) + case *ast.FuncExtractExpr: + c.funcExtract(v) + case *ast.FuncConvertExpr: + c.funcConvert(v) + case *ast.FuncCastExpr: + c.funcCast(v) + case *ast.FuncSubstringExpr: + c.funcSubstring(v) + case *ast.FuncLocateExpr: + c.funcLocate(v) + case *ast.FuncTrimExpr: + c.funcTrim(v) + case *ast.AggregateFuncExpr: + c.aggregateFunc(v) + } + return in, c.err == nil +} + +func (c *expressionConverter) value(v *ast.ValueExpr) { + c.exprMap[v] = expression.Value{Val: v.GetValue()} +} + +func (c *expressionConverter) between(v *ast.BetweenExpr) { + oldExpr := c.exprMap[v.Expr] + oldLo := c.exprMap[v.Left] + oldHi := c.exprMap[v.Right] + oldBetween, err := expression.NewBetween(oldExpr, oldLo, oldHi, v.Not) + if err != nil { + c.err = err + return + } + c.exprMap[v] = oldBetween +} + +func (c *expressionConverter) binaryOperation(v *ast.BinaryOperationExpr) { + oldeLeft := c.exprMap[v.L] + oldRight := c.exprMap[v.R] + oldBinop := expression.NewBinaryOperation(v.Op, oldeLeft, oldRight) + c.exprMap[v] = oldBinop +} + +func (c *expressionConverter) whenClause(v *ast.WhenClause) { + oldExpr := c.exprMap[v.Expr] + oldResult := c.exprMap[v.Result] + oldWhenClause := &expression.WhenClause{Expr: oldExpr, Result: oldResult} + c.exprMap[v] = oldWhenClause +} + +func (c *expressionConverter) caseExpr(v *ast.CaseExpr) { + oldValue := c.exprMap[v.Value] + oldWhenClauses := make([]*expression.WhenClause, len(v.WhenClauses)) + for i, val := range v.WhenClauses { + oldWhenClauses[i] = c.exprMap[val].(*expression.WhenClause) + } + oldElse := c.exprMap[v.ElseClause] + oldCaseExpr := &expression.FunctionCase{ + Value: oldValue, + WhenClauses: oldWhenClauses, + ElseClause: oldElse, + } + c.exprMap[v] = oldCaseExpr +} + +func (c *expressionConverter) subquery(v *ast.SubqueryExpr) { + oldSubquery := &subquery.SubQuery{} + switch x := v.Query.(type) { + case *ast.SelectStmt: + oldSelect, err := convertSelect(x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldUnion + } + c.exprMap[v] = oldSubquery +} + +func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) { + expr := c.exprMap[v.L] + subquery := c.exprMap[v.R] + oldCmpSubquery := expression.NewCompareSubQuery(v.Op, expr, subquery.(expression.SubQuery), v.All) + c.exprMap[v] = oldCmpSubquery +} + +func concatCIStr(ciStrs ...model.CIStr) string { + var originStrs []string + for _, v := range ciStrs { + if v.O != "" { + originStrs = append(originStrs, v.O) + } + } + return strings.Join(originStrs, ".") +} + +func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) { + ident := &expression.Ident{} + ident.CIStr = model.NewCIStr(concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name)) + c.exprMap[v] = ident +} + +func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) { + oldDefault := &expression.Default{} + if v.Name != nil { + oldDefault.Name = concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name) + } + c.exprMap[v] = oldDefault +} + +func (c *expressionConverter) identifier(v *ast.IdentifierExpr) { + oldIdent := &expression.Ident{} + oldIdent.CIStr = v.Name + c.exprMap[v] = oldIdent +} + +func (c *expressionConverter) existsSubquery(v *ast.ExistsSubqueryExpr) { + subquery := c.exprMap[v.Sel].(expression.SubQuery) + c.exprMap[v] = &expression.ExistsSubQuery{Sel: subquery} +} + +func (c *expressionConverter) patternIn(v *ast.PatternInExpr) { + oldPatternIn := &expression.PatternIn{Not: v.Not} + if v.Sel != nil { + oldPatternIn.Sel = c.exprMap[v.Sel].(expression.SubQuery) + } + oldPatternIn.Expr = c.exprMap[v.Expr] + if v.List != nil { + oldPatternIn.List = make([]expression.Expression, len(v.List)) + for i, v := range v.List { + oldPatternIn.List[i] = c.exprMap[v] + } + } + c.exprMap[v] = oldPatternIn +} + +func (c *expressionConverter) isNull(v *ast.IsNullExpr) { + oldIsNull := &expression.IsNull{Not: v.Not} + oldIsNull.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsNull +} + +func (c *expressionConverter) isTruth(v *ast.IsTruthExpr) { + oldIsTruth := &expression.IsTruth{Not: v.Not, True: v.True} + oldIsTruth.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsTruth +} + +func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { + oldPatternLike := &expression.PatternLike{ + Not: v.Not, + Escape: v.Escape, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternLike +} + +func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { + c.exprMap[v] = &expression.ParamMarker{} +} + +func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { + oldExpr := c.exprMap[v.Expr] + c.exprMap[v] = &expression.PExpr{Expr: oldExpr} +} + +func (c *expressionConverter) position(v *ast.PositionExpr) { + c.exprMap[v] = &expression.Position{N: v.N, Name: v.Name} +} + +func (c *expressionConverter) patternRegexp(v *ast.PatternRegexpExpr) { + oldPatternRegexp := &expression.PatternRegexp{ + Not: v.Not, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternRegexp +} + +func (c *expressionConverter) row(v *ast.RowExpr) { + oldRow := &expression.Row{} + oldRow.Values = make([]expression.Expression, len(v.Values)) + for i, val := range v.Values { + oldRow.Values[i] = c.exprMap[val] + } + c.exprMap[v] = oldRow +} + +func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) { + oldUnary := &expression.UnaryOperation{ + Op: v.Op, + V: c.exprMap[v.V], + } + c.exprMap[v] = oldUnary +} + +func (c *expressionConverter) values(v *ast.ValuesExpr) { + nameStr := concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)} +} + +func (c *expressionConverter) variable(v *ast.VariableExpr) { + c.exprMap[v] = &expression.Variable{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + } +} + +func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) { + oldCall := &expression.Call{ + F: v.F, + } + oldCall.Args = make([]expression.Expression, len(v.Args)) + for i, val := range v.Args { + oldCall.Args[i] = c.exprMap[val] + } + c.exprMap[v] = oldCall +} + +func (c *expressionConverter) funcExtract(v *ast.FuncExtractExpr) { + oldExtract := &expression.Extract{Unit: v.Unit} + oldExtract.Date = c.exprMap[v.Date] + c.exprMap[v] = oldExtract +} + +func (c *expressionConverter) funcConvert(v *ast.FuncConvertExpr) { + c.exprMap[v] = &expression.FunctionConvert{ + Expr: c.exprMap[v.Expr], + Charset: v.Charset, + } +} + +func (c *expressionConverter) funcCast(v *ast.FuncCastExpr) { + oldCast := &expression.FunctionCast{ + Expr: c.exprMap[v.Expr], + Tp: v.Tp, + } + switch v.FunctionType { + case ast.CastBinaryOperator: + oldCast.FunctionType = expression.BinaryOperator + case ast.CastConvertFunction: + oldCast.FunctionType = expression.ConvertFunction + case ast.CastFunction: + oldCast.FunctionType = expression.CastFunction + } + c.exprMap[v] = oldCast +} + +func (c *expressionConverter) funcSubstring(v *ast.FuncSubstringExpr) { + oldSubstring := &expression.FunctionSubstring{ + Len: c.exprMap[v.Len], + Pos: c.exprMap[v.Pos], + StrExpr: c.exprMap[v.StrExpr], + } + c.exprMap[v] = oldSubstring +} + +func (c *expressionConverter) funcLocate(v *ast.FuncLocateExpr) { + oldLocate := &expression.FunctionLocate{ + Pos: c.exprMap[v.Pos], + Str: c.exprMap[v.Str], + SubStr: c.exprMap[v.SubStr], + } + c.exprMap[v] = oldLocate +} + +func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) { + oldTrim := &expression.FunctionTrim{ + Str: c.exprMap[v.Str], + RemStr: c.exprMap[v.RemStr], + } + switch v.Direction { + case ast.TrimBoth: + oldTrim.Direction = expression.TrimBoth + case ast.TrimBothDefault: + oldTrim.Direction = expression.TrimBothDefault + case ast.TrimLeading: + oldTrim.Direction = expression.TrimLeading + case ast.TrimTrailing: + oldTrim.Direction = expression.TrimTrailing + } + c.exprMap[v] = oldTrim +} + +func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) { + oldAggregate := &expression.Call{ + F: v.F, + Distinct: v.Distinct, + } + for _, val := range v.Args { + oldAggregate.Args = append(oldAggregate.Args, c.exprMap[val]) + } + c.exprMap[v] = oldAggregate +} diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go new file mode 100644 index 0000000000..34e81052b9 --- /dev/null +++ b/optimizer/convert_stmt.go @@ -0,0 +1,562 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/parser/coldef" + "github.com/pingcap/tidb/rset/rsets" + "github.com/pingcap/tidb/stmt" + "github.com/pingcap/tidb/stmt/stmts" + "github.com/pingcap/tidb/table" +) + +func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { + oldInsert := &stmts.InsertIntoStmt{ + Text: v.Text(), + } + return oldInsert, nil +} + +func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { + oldDelete := &stmts.DeleteStmt{ + Text: v.Text(), + } + return oldDelete, nil +} + +func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { + oldUpdate := &stmts.UpdateStmt{ + Text: v.Text(), + } + return oldUpdate, nil +} + +func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { + converter := &expressionConverter{ + exprMap: map[ast.Node]expression.Expression{}, + } + oldSelect := &stmts.SelectStmt{} + oldSelect.Distinct = s.Distinct + oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) + for i, val := range s.Fields.Fields { + oldField := &field.Field{} + oldField.AsName = val.AsName.O + var err error + oldField.Expr, err = convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldSelect.Fields[i] = oldField + } + var err error + if s.From != nil { + oldSelect.From, err = convertJoin(converter, s.From.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Where != nil { + oldSelect.Where = &rsets.WhereRset{} + oldSelect.Where.Expr, err = convertExpr(converter, s.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + + if s.GroupBy != nil { + oldSelect.GroupBy, err = convertGroupBy(converter, s.GroupBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Having != nil { + oldSelect.Having, err = convertHaving(converter, s.Having) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.OrderBy != nil { + oldSelect.OrderBy, err = convertOrderBy(converter, s.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Limit != nil { + if s.Limit.Offset > 0 { + oldSelect.Offset = &rsets.OffsetRset{Count: s.Limit.Offset} + } + if s.Limit.Count > 0 { + oldSelect.Limit = &rsets.LimitRset{Count: s.Limit.Count} + } + } + switch s.LockTp { + case ast.SelectLockForUpdate: + oldSelect.Lock = coldef.SelectLockForUpdate + case ast.SelectLockInShareMode: + oldSelect.Lock = coldef.SelectLockInShareMode + case ast.SelectLockNone: + oldSelect.Lock = coldef.SelectLockNone + } + return oldSelect, nil +} + +func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { + oldUnion := &stmts.UnionStmt{} + oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) + oldUnion.Distincts = make([]bool, len(u.Selects)-1) + if u.Distinct { + for i := range oldUnion.Distincts { + oldUnion.Distincts[i] = true + } + } + for i, val := range u.Selects { + oldSelect, err := convertSelect(val) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.Selects[i] = oldSelect + } + if u.Limit != nil { + if u.Limit.Offset > 0 { + oldUnion.Offset = &rsets.OffsetRset{Count: u.Limit.Offset} + } + if u.Limit.Count > 0 { + oldUnion.Limit = &rsets.LimitRset{Count: u.Limit.Count} + } + } + // Union order by can not converted to old because it is pushed to select statements. + return oldUnion, nil +} + +func convertJoin(converter *expressionConverter, join *ast.Join) (*rsets.JoinRset, error) { + oldJoin := &rsets.JoinRset{} + switch join.Tp { + case ast.CrossJoin: + oldJoin.Type = rsets.CrossJoin + case ast.LeftJoin: + oldJoin.Type = rsets.LeftJoin + case ast.RightJoin: + oldJoin.Type = rsets.RightJoin + } + switch l := join.Left.(type) { + case *ast.Join: + oldLeft, err := convertJoin(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + case *ast.TableSource: + oldLeft, err := convertTableSource(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + } + + switch r := join.Left.(type) { + case *ast.Join: + oldRight, err := convertJoin(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case *ast.TableSource: + oldRight, err := convertTableSource(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case nil: + } + + if join.On != nil { + oldOn, err := convertExpr(converter, join.On.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.On = oldOn + } + return oldJoin, nil +} + +func convertTableSource(converter *expressionConverter, ts *ast.TableSource) (*rsets.TableSource, error) { + oldTs := &rsets.TableSource{} + oldTs.Name = ts.AsName.O + switch src := ts.Source.(type) { + case *ast.TableName: + oldTs.Source = table.Ident{Schema: src.Schema, Name: src.Name} + case *ast.SelectStmt: + oldSelect, err := convertSelect(src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldUnion + } + return oldTs, nil +} + +func convertGroupBy(converter *expressionConverter, gb *ast.GroupByClause) (*rsets.GroupByRset, error) { + oldGroupBy := &rsets.GroupByRset{ + By: make([]expression.Expression, len(gb.Items)), + } + for i, val := range gb.Items { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldGroupBy.By[i] = oldExpr + } + return oldGroupBy, nil +} + +func convertHaving(converter *expressionConverter, having *ast.HavingClause) (*rsets.HavingRset, error) { + oldHaving := &rsets.HavingRset{} + oldExpr, err := convertExpr(converter, having.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldHaving.Expr = oldExpr + return oldHaving, nil +} + +func convertOrderBy(converter *expressionConverter, orderBy *ast.OrderByClause) (*rsets.OrderByRset, error) { + oldOrderBy := &rsets.OrderByRset{} + oldOrderBy.By = make([]rsets.OrderByItem, len(orderBy.Items)) + for i, val := range orderBy.Items { + oldByItemExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldOrderBy.By[i].Expr = oldByItemExpr + oldOrderBy.By[i].Asc = !val.Desc + } + return oldOrderBy, nil +} + +func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { + oldCreateDatabase := &stmts.CreateDatabaseStmt{ + IfNotExists: v.IfNotExists, + Name: v.Name, + Text: v.Text(), + } + if len(v.Options) != 0 { + oldCreateDatabase.Opt = &coldef.CharsetOpt{} + for _, val := range v.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + oldCreateDatabase.Opt.Chs = val.Value + case ast.DatabaseOptionCollate: + oldCreateDatabase.Opt.Col = val.Value + } + } + } + return oldCreateDatabase, nil +} + +func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { + return &stmts.DropDatabaseStmt{ + IfExists: v.IfExists, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { + oldCreateTable := &stmts.CreateTableStmt{ + Text: v.Text(), + } + return oldCreateTable, nil +} + +func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { + oldDropTable := &stmts.DropTableStmt{ + IfExists: v.IfExists, + Text: v.Text(), + } + oldDropTable.TableIdents = make([]table.Ident, len(v.Tables)) + for i, val := range v.Tables { + oldDropTable.TableIdents[i] = table.Ident{ + Schema: val.Schema, + Name: val.Name, + } + } + return oldDropTable, nil +} + +func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { + oldCreateIndex := &stmts.CreateIndexStmt{ + IndexName: v.IndexName, + Unique: v.Unique, + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + } + oldCreateIndex.IndexColNames = make([]*coldef.IndexColName, len(v.IndexColNames)) + for i, val := range v.IndexColNames { + oldIndexColName := &coldef.IndexColName{ + ColumnName: concatCIStr(val.Column.Schema, val.Column.Table, val.Column.Name), + Length: val.Length, + } + oldCreateIndex.IndexColNames[i] = oldIndexColName + } + return oldCreateIndex, nil +} + +func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { + return &stmts.DropIndexStmt{ + IfExists: v.IfExists, + IndexName: v.IndexName, + Text: v.Text(), + }, nil +} + +func convertAlterTable(v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { + oldAlterTable := &stmts.AlterTableStmt{ + Text: v.Text(), + } + return oldAlterTable, nil +} + +func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { + return &stmts.TruncateTableStmt{ + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + }, nil +} + +func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { + oldExplain := &stmts.ExplainStmt{ + Text: v.Text(), + } + var err error + switch x := v.Stmt.(type) { + case *ast.SelectStmt: + oldExplain.S, err = convertSelect(x) + case *ast.UpdateStmt: + oldExplain.S, err = convertUpdate(x) + case *ast.DeleteStmt: + oldExplain.S, err = convertDelete(x) + case *ast.InsertStmt: + oldExplain.S, err = convertInsert(x) + } + if err != nil { + return nil, errors.Trace(err) + } + return oldExplain, nil +} + +func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { + oldPrepare := &stmts.PreparedStmt{ + InPrepare: true, + Name: v.Name, + SQLText: v.SQLText, + Text: v.Text(), + } + if v.SQLVar != nil { + converter := newExpressionConverter() + oldSQLVar, err := convertExpr(converter, v.SQLVar) + if err != nil { + return nil, errors.Trace(err) + } + oldPrepare.SQLVar = oldSQLVar.(*expression.Variable) + } + return oldPrepare, nil +} + +func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { + return &stmts.DeallocateStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { + oldExec := &stmts.ExecuteStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + } + oldExec.UsingVars = make([]expression.Expression, len(v.UsingVars)) + converter := newExpressionConverter() + for i, val := range v.UsingVars { + oldVar, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldExec.UsingVars[i] = oldVar + } + return oldExec, nil +} + +func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { + oldShow := &stmts.ShowStmt{ + DBName: v.DBName, + Flag: v.Flag, + Full: v.Full, + GlobalScope: v.GlobalScope, + Text: v.Text(), + } + if v.Table != nil { + oldShow.TableIdent = table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + } + } + if v.Column != nil { + oldShow.ColumnName = concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + } + if v.Where != nil { + converter := newExpressionConverter() + oldWhere, err := convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Where = oldWhere + } + if v.Pattern != nil { + converter := newExpressionConverter() + oldPattern, err := convertExpr(converter, v.Pattern) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Pattern = oldPattern.(*expression.PatternLike) + } + switch v.Tp { + case ast.ShowCharset: + oldShow.Target = stmt.ShowCharset + case ast.ShowCollation: + oldShow.Target = stmt.ShowCollation + case ast.ShowColumns: + oldShow.Target = stmt.ShowColumns + case ast.ShowCreateTable: + oldShow.Target = stmt.ShowCreateTable + case ast.ShowDatabases: + oldShow.Target = stmt.ShowDatabases + case ast.ShowTables: + oldShow.Target = stmt.ShowTables + case ast.ShowEngines: + oldShow.Target = stmt.ShowEngines + case ast.ShowVariables: + oldShow.Target = stmt.ShowVariables + case ast.ShowWarnings: + oldShow.Target = stmt.ShowWarnings + case ast.ShowNone: + oldShow.Target = stmt.ShowNone + } + return oldShow, nil +} + +func convertBegin(v *ast.BeginStmt) (*stmts.BeginStmt, error) { + return &stmts.BeginStmt{ + Text: v.Text(), + }, nil +} + +func convertCommit(v *ast.CommitStmt) (*stmts.CommitStmt, error) { + return &stmts.CommitStmt{ + Text: v.Text(), + }, nil +} + +func convertRollback(v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { + return &stmts.RollbackStmt{ + Text: v.Text(), + }, nil +} + +func convertUse(v *ast.UseStmt) (*stmts.UseStmt, error) { + return &stmts.UseStmt{ + DBName: v.DBName, + Text: v.Text(), + }, nil +} + +func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAssignment) (*stmts.VariableAssignment, error) { + oldValue, err := convertExpr(converter, v.Value) + if err != nil { + return nil, errors.Trace(err) + } + + return &stmts.VariableAssignment{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + Value: oldValue, + Text: v.Text(), + }, nil +} + +func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { + oldSet := &stmts.SetStmt{ + Text: v.Text(), + Variables: make([]*stmts.VariableAssignment, len(v.Variables)), + } + converter := newExpressionConverter() + for i, val := range v.Variables { + oldAssign, err := convertVariableAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldSet.Variables[i] = oldAssign + } + return oldSet, nil +} + +func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { + return &stmts.SetCharsetStmt{ + Charset: v.Charset, + Collate: v.Collate, + Text: v.Text(), + }, nil +} + +func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { + return &stmts.SetPwdStmt{ + User: v.User, + Password: v.Password, + Text: v.Text(), + }, nil +} + +func convertDo(v *ast.DoStmt) (*stmts.DoStmt, error) { + exprConverter := newExpressionConverter() + oldDo := &stmts.DoStmt{ + Text: v.Text(), + Exprs: make([]expression.Expression, len(v.Exprs)), + } + for i, val := range v.Exprs { + oldExpr, err := convertExpr(exprConverter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldDo.Exprs[i] = oldExpr + } + return oldDo, nil +} diff --git a/optimizer/evaluator.go b/optimizer/evaluator.go index 52dab1d6f2..fffb5b53b3 100644 --- a/optimizer/evaluator.go +++ b/optimizer/evaluator.go @@ -14,16 +14,13 @@ package optimizer import ( - "math" - "github.com/juju/errors" "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan" - "github.com/pingcap/tidb/util/types" ) +// Evaluator is a ast visitor that evaluates an expression. type Evaluator struct { // columnMap is the map from ColumnName to the position of the rowStack. // It is used to find the value of the column. @@ -51,44 +48,88 @@ type position struct { columnOffset int } +// Enter implements ast.Visitor interface. func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { return } +// Leave implements ast.Visitor interface. func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) { switch v := in.(type) { case *ast.ValueExpr: ok = true - case *ast.AggregateFuncExpr: - ok = e.handleAggregateFunc(v) case *ast.BetweenExpr: - ok = e.handleBetween(v) + ok = e.between(v) case *ast.BinaryOperationExpr: - ok = e.handleBinaryOperation(v) + ok = e.binaryOperation(v) + case *ast.WhenClause: + ok = e.whenClause(v) + case *ast.CaseExpr: + ok = e.caseExpr(v) + case *ast.SubqueryExpr: + ok = e.subquery(v) + case *ast.CompareSubqueryExpr: + ok = e.compareSubquery(v) + case *ast.ColumnNameExpr: + ok = e.columnName(v) + case *ast.DefaultExpr: + ok = e.defaultExpr(v) + case *ast.IdentifierExpr: + ok = e.identifier(v) + case *ast.ExistsSubqueryExpr: + ok = e.existsSubquery(v) + case *ast.PatternInExpr: + ok = e.patternIn(v) + case *ast.IsNullExpr: + ok = e.isNull(v) + case *ast.IsTruthExpr: + ok = e.isTruth(v) + case *ast.PatternLikeExpr: + ok = e.patternLike(v) + case *ast.ParamMarkerExpr: + ok = e.paramMarker(v) + case *ast.ParenthesesExpr: + ok = e.parentheses(v) + case *ast.PositionExpr: + ok = e.position(v) + case *ast.PatternRegexpExpr: + ok = e.patternRegexp(v) + case *ast.RowExpr: + ok = e.row(v) + case *ast.UnaryOperationExpr: + ok = e.unaryOperation(v) + case *ast.ValuesExpr: + ok = e.values(v) + case *ast.VariableExpr: + ok = e.variable(v) + case *ast.FuncCallExpr: + ok = e.funcCall(v) + case *ast.FuncExtractExpr: + ok = e.funcExtract(v) + case *ast.FuncConvertExpr: + ok = e.funcConvert(v) + case *ast.FuncCastExpr: + ok = e.funcCast(v) + case *ast.FuncSubstringExpr: + ok = e.funcSubstring(v) + case *ast.FuncLocateExpr: + ok = e.funcLocate(v) + case *ast.FuncTrimExpr: + ok = e.funcTrim(v) + case *ast.AggregateFuncExpr: + ok = e.aggregateFunc(v) } out = in return } -func (e *Evaluator) handleAggregateFunc(v *ast.AggregateFuncExpr) bool { - idx := e.aggregateMap[v] - aggr := e.aggregators[idx] - if e.aggregateDone { - v.SetValue(aggr.Output()) - return true - } - // TODO: currently only single argument aggregate functions are supported. - e.err = aggr.Input(v.Args[0].GetValue()) - return e.err == nil -} - func checkAllOneColumn(exprs ...ast.ExprNode) bool { for _, expr := range exprs { switch v := expr.(type) { case *ast.RowExpr: return false case *ast.SubqueryExpr: - if len(v.Query.Fields.Fields) != 1 { + if len(v.Query.GetResultFields()) != 1 { return false } } @@ -96,7 +137,7 @@ func checkAllOneColumn(exprs ...ast.ExprNode) bool { return true } -func (e *Evaluator) handleBetween(v *ast.BetweenExpr) bool { +func (e *Evaluator) between(v *ast.BetweenExpr) bool { if !checkAllOneColumn(v.Expr, v.Left, v.Right) { e.err = errors.Errorf("Operand should contain 1 column(s)") return false @@ -108,520 +149,19 @@ func (e *Evaluator) handleBetween(v *ast.BetweenExpr) bool { if v.Not { // v < lv || v > rv op = opcode.OrOr - l = &ast.BinaryOperationExpr{opcode.LT, v.Expr, v.Left} - r = &ast.BinaryOperationExpr{opcode.GT, v.Expr, v.Right} + l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right} } else { // v >= lv && v <= rv - l = &ast.BinaryOperationExpr{opcode.GE, v.Expr, v.Left} - r = &ast.BinaryOperationExpr{opcode.LE, v.Expr, v.Right} + l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right} } - ret := &ast.BinaryOperationExpr{op, l, r} + ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r} ret.Accept(e) return e.err == nil } -const ( - zeroI64 int64 = 0 - oneI64 int64 = 1 -) - -func (e *Evaluator) handleBinaryOperation(o *ast.BinaryOperationExpr) bool { - // all operands must have same column. - if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { - return nil, false - } - - // row constructor only supports comparison operation. - switch o.Op { - case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: - default: - if !checkAllOneColumn(o.L) { - return nil, errors.Errorf("Operand should contain 1 column(s)") - } - } - - leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - if leftVal == nil || rightVal == nil { - o.SetValue(nil) - return true - } - - switch o.Op { - case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: - return e.handleLogicOperation(o) - case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: - return e.handleComparisonOp(o) - case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: - // TODO: MySQL doesn't support and not, we should remove it later. - return e.handleBitOp(o) - case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: - return e.handleArithmeticOp(o) - default: - panic("should never happen") - } -} - -func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { - leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - if leftVal == nil || rightVal == nil { - o.SetValue(nil) - return true - } - var boolVal bool - switch o.Op { - case opcode.AndAnd: - boolVal = leftVal != zeroI64 && rightVal != zeroI64 - case opcode.OrOr: - boolVal = leftVal != zeroI64 || rightVal != zeroI64 - case opcode.LogicXor: - boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) - default: - panic("should never happen") - } - if boolVal { - o.SetValue(oneI64) - } else { - o.SetValue(zeroI64) - } - return true -} - -func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { - a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) - if types.IsNil(a) || types.IsNil(b) { - // for <=>, if a and b are both nil, return true. - // if a or b is nil, return false. - if o.Op == opcode.NullEQ { - if types.IsNil(a) || types.IsNil(b) { - o.SetValue(oneI64) - } else { - o.SetValue(zeroI64) - } - } else { - o.SetValue(nil) - } - return true - } - - n, err := types.Compare(a, b) - if err != nil { - e.err = err - return false - } - - r, err := getCompResult(o.Op, n) - if err != nil { - e.err = err - return false - } - if r { - o.SetValue(oneI64) - } else { - o.SetValue(zeroI64) - } - return true -} - -func getCompResult(op opcode.Op, value int) (bool, error) { - switch op { - case opcode.LT: - return value < 0, nil - case opcode.LE: - return value <= 0, nil - case opcode.GE: - return value >= 0, nil - case opcode.GT: - return value > 0, nil - case opcode.EQ: - return value == 0, nil - case opcode.NE: - return value != 0, nil - case opcode.NullEQ: - return value == 0, nil - default: - return false, errors.Errorf("invalid op %v in comparision operation", op) - } -} - -func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { - a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) - - if types.IsNil(a) || types.IsNil(b) { - return nil, nil - } - - x, err := types.ToInt64(a) - if err != nil { - e.err = err - return false - } - - y, err := types.ToInt64(b) - if err != nil { - e.err = err - return false - } - - // use a int64 for bit operator, return uint64 - switch o.Op { - case opcode.And: - o.SetValue(uint64(x & y)) - case opcode.Or: - o.SetValue(uint64(x | y)) - case opcode.Xor: - o.SetValue(uint64(x ^ y)) - case opcode.RightShift: - o.SetValue(uint64(x) >> uint64(y)) - case opcode.LeftShift: - o.SetValue(uint64(x) << uint64(y)) - default: - e.err = errors.Errorf("invalid op %v in bit operation", o.Op) - return false - } - return true -} - -func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { - a, err := coerceArithmetic(o.L.GetValue()) - if err != nil { - e.err = err - return false - } - - b, err := coerceArithmetic(o.R.GetValue()) - if err != nil { - e.err = err - return false - } - a, b = types.Coerce(a, b) - - if a == nil || b == nil { - // TODO: for <=>, if a and b are both nil, return true - o.SetValue(nil) - return true - } - - // TODO: support logic division DIV - var result interface{} - switch o.Op { - case opcode.Plus: - result, e.err = computePlus(a, b) - case opcode.Minus: - result, e.err = computeMinus(a, b) - case opcode.Mul: - result, e.err = computeMul(a, b) - case opcode.Div: - result, e.err = computeDiv(a, b) - case opcode.Mod: - result, e.err = computeMod(a, b) - case opcode.IntDiv: - result, e.err = computeIntDiv(a, b) - default: - e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) - return false - } - o.SetValue(result) - return e.err == nil -} - -func computePlus(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - return types.AddInt64(x, y) - case uint64: - return types.AddInteger(y, x) - } - case uint64: - switch y := b.(type) { - case int64: - return types.AddInteger(x, y) - case uint64: - return types.AddUint64(x, y) - } - case float64: - switch y := b.(type) { - case float64: - return x + y, nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - return x.Add(y), nil - } - } - return types.InvOp2(a, b, opcode.Plus) -} - -func computeMinus(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - return types.SubInt64(x, y) - case uint64: - return types.SubIntWithUint(x, y) - } - case uint64: - switch y := b.(type) { - case int64: - return types.SubUintWithInt(x, y) - case uint64: - return types.SubUint64(x, y) - } - case float64: - switch y := b.(type) { - case float64: - return x - y, nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - return x.Sub(y), nil - } - } - - return types.InvOp2(a, b, opcode.Minus) -} - -func computeMul(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - return types.MulInt64(x, y) - case uint64: - return types.MulInteger(y, x) - } - case uint64: - switch y := b.(type) { - case int64: - return types.MulInteger(x, y) - case uint64: - return types.MulUint64(x, y) - } - case float64: - switch y := b.(type) { - case float64: - return x * y, nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - return x.Mul(y), nil - } - } - - return types.InvOp2(a, b, opcode.Mul) -} - -func computeDiv(a, b interface{}) (interface{}, error) { - // MySQL support integer divison Div and division operator / - // we use opcode.Div for division operator and will use another for integer division later. - // for division operator, we will use float64 for calculation. - switch x := a.(type) { - case float64: - y, err := types.ToFloat64(b) - if err != nil { - return nil, err - } - - if y == 0 { - return nil, nil - } - - return x / y, nil - default: - // the scale of the result is the scale of the first operand plus - // the value of the div_precision_increment system variable (which is 4 by default) - // we will use 4 here - - xa, err := types.ToDecimal(a) - if err != nil { - return nil, err - } - - xb, err := types.ToDecimal(b) - if err != nil { - return nil, err - } - if f, _ := xb.Float64(); f == 0 { - // division by zero return null - return nil, nil - } - - return xa.Div(xb), nil - } -} - -func computeMod(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } - return x % y, nil - case uint64: - if y == 0 { - return nil, nil - } else if x < 0 { - // first is int64, return int64. - return -int64(uint64(-x) % y), nil - } - return int64(uint64(x) % y), nil - } - case uint64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } else if y < 0 { - // first is uint64, return uint64. - return uint64(x % uint64(-y)), nil - } - return x % uint64(y), nil - case uint64: - if y == 0 { - return nil, nil - } - return x % y, nil - } - case float64: - switch y := b.(type) { - case float64: - if y == 0 { - return nil, nil - } - return math.Mod(x, y), nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - xf, _ := x.Float64() - yf, _ := y.Float64() - if yf == 0 { - return nil, nil - } - return math.Mod(xf, yf), nil - } - } - - return types.InvOp2(a, b, opcode.Mod) -} - -func computeIntDiv(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } - return types.DivInt64(x, y) - case uint64: - if y == 0 { - return nil, nil - } - return types.DivIntWithUint(x, y) - } - case uint64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } - return types.DivUintWithInt(x, y) - case uint64: - if y == 0 { - return nil, nil - } - return x / y, nil - } - } - - // if any is none integer, use decimal to calculate - x, err := types.ToDecimal(a) - if err != nil { - return nil, err - } - - y, err := types.ToDecimal(b) - if err != nil { - return nil, err - } - - if f, _ := y.Float64(); f == 0 { - return nil, nil - } - - return x.Div(y).IntPart(), nil -} - -func coerceArithmetic(a interface{}) (interface{}, error) { - switch x := a.(type) { - case string: - // MySQL will convert string to float for arithmetic operation - f, err := types.StrToFloat(x) - if err != nil { - return nil, err - } - return f, err - case mysql.Time: - // if time has no precision, return int64 - v := x.ToNumber() - if x.Fsp == 0 { - return v.IntPart(), nil - } - return v, nil - case mysql.Duration: - // if duration has no precision, return int64 - v := x.ToNumber() - if x.Fsp == 0 { - return v.IntPart(), nil - } - return v, nil - case []byte: - // []byte is the same as string, converted to float for arithmetic operator. - f, err := types.StrToFloat(string(x)) - if err != nil { - return nil, err - } - return f, err - case mysql.Hex: - return x.ToNumber(), nil - case mysql.Bit: - return x.ToNumber(), nil - case mysql.Enum: - return x.ToNumber(), nil - case mysql.Set: - return x.ToNumber(), nil - default: - return x, nil - } -} - func columnCount(e ast.ExprNode) (int, error) { switch x := e.(type) { case *ast.RowExpr: @@ -631,7 +171,7 @@ func columnCount(e ast.ExprNode) (int, error) { } return n, nil case *ast.SubqueryExpr: - return len(x.Query.Fields.Fields), nil + return len(x.Query.GetResultFields()), nil default: return 1, nil } @@ -655,3 +195,123 @@ func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error { } return nil } + +func (e *Evaluator) whenClause(v *ast.WhenClause) bool { + return true +} + +func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { + return true +} + +func (e *Evaluator) subquery(v *ast.SubqueryExpr) bool { + return true +} + +func (e *Evaluator) compareSubquery(v *ast.CompareSubqueryExpr) bool { + return true +} + +func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool { + return true +} + +func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool { + return true +} + +func (e *Evaluator) identifier(v *ast.IdentifierExpr) bool { + return true +} + +func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { + return true +} + +func (e *Evaluator) patternIn(v *ast.PatternInExpr) bool { + return true +} + +func (e *Evaluator) isNull(v *ast.IsNullExpr) bool { + return true +} + +func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { + return true +} + +func (e *Evaluator) patternLike(v *ast.PatternLikeExpr) bool { + return true +} + +func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool { + return true +} + +func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool { + return true +} + +func (e *Evaluator) position(v *ast.PositionExpr) bool { + return true +} + +func (e *Evaluator) patternRegexp(v *ast.PatternRegexpExpr) bool { + return true +} + +func (e *Evaluator) row(v *ast.RowExpr) bool { + return true +} + +func (e *Evaluator) unaryOperation(v *ast.UnaryOperationExpr) bool { + return true +} + +func (e *Evaluator) values(v *ast.ValuesExpr) bool { + return true +} + +func (e *Evaluator) variable(v *ast.VariableExpr) bool { + return true +} + +func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool { + return true +} + +func (e *Evaluator) funcExtract(v *ast.FuncExtractExpr) bool { + return true +} + +func (e *Evaluator) funcConvert(v *ast.FuncConvertExpr) bool { + return true +} + +func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool { + return true +} + +func (e *Evaluator) funcSubstring(v *ast.FuncSubstringExpr) bool { + return true +} + +func (e *Evaluator) funcLocate(v *ast.FuncLocateExpr) bool { + return true +} + +func (e *Evaluator) funcTrim(v *ast.FuncTrimExpr) bool { + return true +} + +func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { + idx := e.aggregateMap[v] + aggr := e.aggregators[idx] + if e.aggregateDone { + v.SetValue(aggr.Output()) + return true + } + // TODO: currently only single argument aggregate functions are supported. + e.err = aggr.Input(v.Args[0].GetValue()) + return e.err == nil +} diff --git a/optimizer/evaluator_binop.go b/optimizer/evaluator_binop.go new file mode 100644 index 0000000000..0cda326144 --- /dev/null +++ b/optimizer/evaluator_binop.go @@ -0,0 +1,527 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "math" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" +) + +const ( + zeroI64 int64 = 0 + oneI64 int64 = 1 +) + +func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool { + // all operands must have same column. + if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { + return false + } + + // row constructor only supports comparison operation. + switch o.Op { + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + default: + if !checkAllOneColumn(o.L) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + } + + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + + switch o.Op { + case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: + return e.handleLogicOperation(o) + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + return e.handleComparisonOp(o) + case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: + // TODO: MySQL doesn't support and not, we should remove it later. + return e.handleBitOp(o) + case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: + return e.handleArithmeticOp(o) + default: + panic("should never happen") + } +} + +func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + var boolVal bool + switch o.Op { + case opcode.AndAnd: + boolVal = leftVal != zeroI64 && rightVal != zeroI64 + case opcode.OrOr: + boolVal = leftVal != zeroI64 || rightVal != zeroI64 + case opcode.LogicXor: + boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) + default: + panic("should never happen") + } + if boolVal { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + if types.IsNil(a) || types.IsNil(b) { + // for <=>, if a and b are both nil, return true. + // if a or b is nil, return false. + if o.Op == opcode.NullEQ { + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + } else { + o.SetValue(nil) + } + return true + } + + n, err := types.Compare(a, b) + if err != nil { + e.err = err + return false + } + + r, err := getCompResult(o.Op, n) + if err != nil { + e.err = err + return false + } + if r { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func getCompResult(op opcode.Op, value int) (bool, error) { + switch op { + case opcode.LT: + return value < 0, nil + case opcode.LE: + return value <= 0, nil + case opcode.GE: + return value >= 0, nil + case opcode.GT: + return value > 0, nil + case opcode.EQ: + return value == 0, nil + case opcode.NE: + return value != 0, nil + case opcode.NullEQ: + return value == 0, nil + default: + return false, errors.Errorf("invalid op %v in comparision operation", op) + } +} + +func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(nil) + return true + } + + x, err := types.ToInt64(a) + if err != nil { + e.err = err + return false + } + + y, err := types.ToInt64(b) + if err != nil { + e.err = err + return false + } + + // use a int64 for bit operator, return uint64 + switch o.Op { + case opcode.And: + o.SetValue(uint64(x & y)) + case opcode.Or: + o.SetValue(uint64(x | y)) + case opcode.Xor: + o.SetValue(uint64(x ^ y)) + case opcode.RightShift: + o.SetValue(uint64(x) >> uint64(y)) + case opcode.LeftShift: + o.SetValue(uint64(x) << uint64(y)) + default: + e.err = errors.Errorf("invalid op %v in bit operation", o.Op) + return false + } + return true +} + +func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { + a, err := coerceArithmetic(o.L.GetValue()) + if err != nil { + e.err = err + return false + } + + b, err := coerceArithmetic(o.R.GetValue()) + if err != nil { + e.err = err + return false + } + a, b = types.Coerce(a, b) + + if a == nil || b == nil { + // TODO: for <=>, if a and b are both nil, return true + o.SetValue(nil) + return true + } + + // TODO: support logic division DIV + var result interface{} + switch o.Op { + case opcode.Plus: + result, e.err = computePlus(a, b) + case opcode.Minus: + result, e.err = computeMinus(a, b) + case opcode.Mul: + result, e.err = computeMul(a, b) + case opcode.Div: + result, e.err = computeDiv(a, b) + case opcode.Mod: + result, e.err = computeMod(a, b) + case opcode.IntDiv: + result, e.err = computeIntDiv(a, b) + default: + e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) + return false + } + o.SetValue(result) + return e.err == nil +} + +func computePlus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.AddInt64(x, y) + case uint64: + return types.AddInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.AddInteger(x, y) + case uint64: + return types.AddUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x + y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Add(y), nil + } + } + return types.InvOp2(a, b, opcode.Plus) +} + +func computeMinus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.SubInt64(x, y) + case uint64: + return types.SubIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + return types.SubUintWithInt(x, y) + case uint64: + return types.SubUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x - y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Sub(y), nil + } + } + + return types.InvOp2(a, b, opcode.Minus) +} + +func computeMul(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.MulInt64(x, y) + case uint64: + return types.MulInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.MulInteger(x, y) + case uint64: + return types.MulUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x * y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Mul(y), nil + } + } + + return types.InvOp2(a, b, opcode.Mul) +} + +func computeDiv(a, b interface{}) (interface{}, error) { + // MySQL support integer divison Div and division operator / + // we use opcode.Div for division operator and will use another for integer division later. + // for division operator, we will use float64 for calculation. + switch x := a.(type) { + case float64: + y, err := types.ToFloat64(b) + if err != nil { + return nil, err + } + + if y == 0 { + return nil, nil + } + + return x / y, nil + default: + // the scale of the result is the scale of the first operand plus + // the value of the div_precision_increment system variable (which is 4 by default) + // we will use 4 here + + xa, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + xb, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + if f, _ := xb.Float64(); f == 0 { + // division by zero return null + return nil, nil + } + + return xa.Div(xb), nil + } +} + +func computeMod(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return x % y, nil + case uint64: + if y == 0 { + return nil, nil + } else if x < 0 { + // first is int64, return int64. + return -int64(uint64(-x) % y), nil + } + return int64(uint64(x) % y), nil + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } else if y < 0 { + // first is uint64, return uint64. + return uint64(x % uint64(-y)), nil + } + return x % uint64(y), nil + case uint64: + if y == 0 { + return nil, nil + } + return x % y, nil + } + case float64: + switch y := b.(type) { + case float64: + if y == 0 { + return nil, nil + } + return math.Mod(x, y), nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + xf, _ := x.Float64() + yf, _ := y.Float64() + if yf == 0 { + return nil, nil + } + return math.Mod(xf, yf), nil + } + } + + return types.InvOp2(a, b, opcode.Mod) +} + +func computeIntDiv(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivInt64(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return types.DivIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivUintWithInt(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return x / y, nil + } + } + + // if any is none integer, use decimal to calculate + x, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + y, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + + if f, _ := y.Float64(); f == 0 { + return nil, nil + } + + return x.Div(y).IntPart(), nil +} + +func coerceArithmetic(a interface{}) (interface{}, error) { + switch x := a.(type) { + case string: + // MySQL will convert string to float for arithmetic operation + f, err := types.StrToFloat(x) + if err != nil { + return nil, err + } + return f, err + case mysql.Time: + // if time has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case mysql.Duration: + // if duration has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case []byte: + // []byte is the same as string, converted to float for arithmetic operator. + f, err := types.StrToFloat(string(x)) + if err != nil { + return nil, err + } + return f, err + case mysql.Hex: + return x.ToNumber(), nil + case mysql.Bit: + return x.ToNumber(), nil + case mysql.Enum: + return x.ToNumber(), nil + case mysql.Set: + return x.ToNumber(), nil + default: + return x, nil + } +} diff --git a/optimizer/infobinder_test.go b/optimizer/infobinder_test.go index f2109e0c17..988d6dfaa0 100644 --- a/optimizer/infobinder_test.go +++ b/optimizer/infobinder_test.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer_test import ( diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index aad7c6d732..cec754d164 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -14,9 +14,9 @@ package optimizer import ( + "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/stmt" - "github.com/juju/errors" ) // Compile compiles a ast.Node into a executable statement. @@ -32,24 +32,63 @@ func Compile(node ast.Node) (stmt.Statement, error) { } tpComputer := &typeComputer{} - if _, ok := node.Accept(typeComputer{}); !ok { + if _, ok := node.Accept(tpComputer); !ok { return nil, errors.Trace(tpComputer.err) } switch v := node.(type) { + case *ast.InsertStmt: + return convertInsert(v) + case *ast.DeleteStmt: + return convertDelete(v) + case *ast.UpdateStmt: + return convertUpdate(v) case *ast.SelectStmt: - + return convertSelect(v) + case *ast.UnionStmt: + return convertUnion(v) + case *ast.CreateDatabaseStmt: + return convertCreateDatabase(v) + case *ast.DropDatabaseStmt: + return convertDropDatabase(v) + case *ast.CreateTableStmt: + return convertCreateTable(v) + case *ast.DropTableStmt: + return convertDropTable(v) + case *ast.CreateIndexStmt: + return convertCreateIndex(v) + case *ast.DropIndexStmt: + return convertDropIndex(v) + case *ast.AlterTableStmt: + return convertAlterTable(v) + case *ast.TruncateTableStmt: + return convertTruncateTable(v) + case *ast.ExplainStmt: + return convertExplain(v) + case *ast.PrepareStmt: + return convertPrepare(v) + case *ast.DeallocateStmt: + return convertDeallocate(v) + case *ast.ExecuteStmt: + return convertExecute(v) + case *ast.ShowStmt: + return convertShow(v) + case *ast.BeginStmt: + return convertBegin(v) + case *ast.CommitStmt: + return convertCommit(v) + case *ast.RollbackStmt: + return convertRollback(v) + case *ast.UseStmt: + return convertUse(v) case *ast.SetStmt: - return compileSet(v) + return convertSet(v) + case *ast.SetCharsetStmt: + return convertSetCharset(v) + case *ast.SetPwdStmt: + return convertSetPwd(v) + case *ast.DoStmt: + return convertDo(v) } return nil, nil } - -func compileSelect(s *ast.SelectStmt) (stmt.Statement, error) { - - return nil, nil -} - -func compileSet(s *ast.SetStmt) (stmt.Statement, error) { - return nil, nil -} diff --git a/optimizer/typecomputer.go b/optimizer/typecomputer.go index 3da52497bc..d1862d1380 100644 --- a/optimizer/typecomputer.go +++ b/optimizer/typecomputer.go @@ -1,4 +1,18 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer + import "github.com/pingcap/tidb/ast" // typeComputer is an ast Visitor that @@ -8,9 +22,9 @@ type typeComputer struct { } func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) { - return + return in, false } func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) { - return -} \ No newline at end of file + return in, true +} diff --git a/optimizer/validator.go b/optimizer/validator.go index cd69ebd47a..f90a757ab8 100644 --- a/optimizer/validator.go +++ b/optimizer/validator.go @@ -1,4 +1,18 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer + import "github.com/pingcap/tidb/ast" // validator is an ast.Visitor that validates @@ -8,9 +22,9 @@ type validator struct { } func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { - return + return in, false } func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { - return -} \ No newline at end of file + return in, true +} From edcbbc84ff83f811e4708fadeabb74b14fdcb0e1 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 29 Oct 2015 15:33:54 +0800 Subject: [PATCH 04/25] optimizer: implement ddl statement convert functions. --- ast/dml.go | 4 +- ast/parser/parser.y | 4 + optimizer/convert_expr.go | 18 +-- optimizer/convert_stmt.go | 274 +++++++++++++++++++++++++++++++++++++- 4 files changed, 284 insertions(+), 16 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index d46cd86394..0225af0462 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -564,7 +564,7 @@ type InsertStmt struct { Setlist []*Assignment Priority int OnDuplicate []*Assignment - Select *SelectStmt + Select ResultSetNode } // Accept implements Node Accept interface. @@ -616,7 +616,7 @@ func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { if !ok { return nod, false } - nod.Select = node.(*SelectStmt) + nod.Select = node.(ResultSetNode) } return v.Leave(nod) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index af69b11b61..6f096b8687 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1669,6 +1669,10 @@ InsertRest: { $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} } +| UnionStmt + { + $$ = &ast.InsertStmt{Select: $1.(*ast.UnionStmt)} + } | "SET" ColumnSetValueList { $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 4fe74ebf34..41c78a815e 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -190,26 +190,28 @@ func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) { c.exprMap[v] = oldCmpSubquery } -func concatCIStr(ciStrs ...model.CIStr) string { +func joinColumnName(columnName *ast.ColumnName) string { var originStrs []string - for _, v := range ciStrs { - if v.O != "" { - originStrs = append(originStrs, v.O) - } + if columnName.Schema.O != "" { + originStrs = append(originStrs, columnName.Schema.O) } + if columnName.Table.O != "" { + originStrs = append(originStrs, columnName.Table.O) + } + originStrs = append(originStrs, columnName.Name.O) return strings.Join(originStrs, ".") } func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) { ident := &expression.Ident{} - ident.CIStr = model.NewCIStr(concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name)) + ident.CIStr = model.NewCIStr(joinColumnName(v.Name)) c.exprMap[v] = ident } func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) { oldDefault := &expression.Default{} if v.Name != nil { - oldDefault.Name = concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name) + oldDefault.Name = joinColumnName(v.Name) } c.exprMap[v] = oldDefault } @@ -302,7 +304,7 @@ func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) { } func (c *expressionConverter) values(v *ast.ValuesExpr) { - nameStr := concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + nameStr := joinColumnName(v.Column) c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)} } diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 34e81052b9..036bfebefd 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -25,23 +25,153 @@ import ( "github.com/pingcap/tidb/table" ) +func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expression.Assignment, error) { + oldAssign := &expression.Assignment{ + ColName: joinColumnName(v.Column), + } + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldAssign.Expr = oldExpr + return oldAssign, nil +} + func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { oldInsert := &stmts.InsertIntoStmt{ - Text: v.Text(), + Priority: v.Priority, + Text: v.Text(), + } + tableName := v.Table.TableRefs.Left.(*ast.TableName) + oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} + for _, val := range v.Columns { + oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) + } + converter := newExpressionConverter() + for _, row := range v.Lists { + var oldRow []expression.Expression + for _, val := range row { + oldExpr, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldRow = append(oldRow, oldExpr) + } + oldInsert.Lists = append(oldInsert.Lists, oldRow) + } + for _, assign := range v.Setlist { + oldAssign, err := convertAssignment(converter, assign) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.Setlist = append(oldInsert.Setlist, oldAssign) + } + for _, onDup := range v.OnDuplicate { + oldOnDup, err := convertAssignment(converter, onDup) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, *oldOnDup) + } + if v.Select != nil { + var err error + switch x := v.Select.(type) { + case *ast.SelectStmt: + oldInsert.Sel, err = convertSelect(x) + case *ast.UnionStmt: + oldInsert.Sel, err = convertUnion(x) + } + if err != nil { + return nil, errors.Trace(err) + } } return oldInsert, nil } func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { oldDelete := &stmts.DeleteStmt{ - Text: v.Text(), + BeforeFrom: v.BeforeFrom, + Ignore: v.Ignore, + LowPriority: v.LowPriority, + MultiTable: v.MultiTable, + Quick: v.Quick, + Text: v.Text(), + } + converter := newExpressionConverter() + oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + oldDelete.Refs = oldRefs + for _, val := range v.Tables { + tableIdent := table.Ident{Schema: val.Schema, Name: val.Name} + oldDelete.TableIdents = append(oldDelete.TableIdents, tableIdent) + } + if v.Where != nil { + oldDelete.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Order != nil { + orderRset := &rsets.OrderByRset{} + for _, val := range v.Order { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} + orderRset.By = append(orderRset.By, orderItem) + } + oldDelete.Order = orderRset + } + if v.Limit != nil { + oldDelete.Limit = &rsets.LimitRset{Count: v.Limit.Count} } return oldDelete, nil } func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { oldUpdate := &stmts.UpdateStmt{ - Text: v.Text(), + Ignore: v.Ignore, + MultipleTable: v.MultipleTable, + LowPriority: v.LowPriority, + Text: v.Text(), + } + converter := newExpressionConverter() + var err error + oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + if v.Where != nil { + oldUpdate.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + for _, val := range v.List { + var oldAssign *expression.Assignment + oldAssign, err = convertAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldUpdate.List = append(oldUpdate.List, *oldAssign) + } + if v.Order != nil { + orderRset := &rsets.OrderByRset{} + for _, val := range v.Order { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} + orderRset.By = append(orderRset.By, orderItem) + } + oldUpdate.Order = orderRset + } + if v.Limit != nil { + oldUpdate.Limit = &rsets.LimitRset{Count: v.Limit.Count} } return oldUpdate, nil } @@ -282,9 +412,141 @@ func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, erro }, nil } +func convertColumnOption(converter *expressionConverter, v *ast.ColumnOption) (*coldef.ConstraintOpt, error) { + oldColumnOpt := &coldef.ConstraintOpt{} + switch v.Tp { + case ast.ColumnOptionAutoIncrement: + oldColumnOpt.Tp = coldef.ConstrAutoIncrement + case ast.ColumnOptionComment: + oldColumnOpt.Tp = coldef.ConstrComment + case ast.ColumnOptionDefaultValue: + oldColumnOpt.Tp = coldef.ConstrDefaultValue + case ast.ColumnOptionIndex: + oldColumnOpt.Tp = coldef.ConstrIndex + case ast.ColumnOptionKey: + oldColumnOpt.Tp = coldef.ConstrKey + case ast.ColumnOptionFulltext: + oldColumnOpt.Tp = coldef.ConstrFulltext + case ast.ColumnOptionNotNull: + oldColumnOpt.Tp = coldef.ConstrNotNull + case ast.ColumnOptionNoOption: + oldColumnOpt.Tp = coldef.ConstrNoConstr + case ast.ColumnOptionOnUpdate: + oldColumnOpt.Tp = coldef.ConstrOnUpdate + case ast.ColumnOptionPrimaryKey: + oldColumnOpt.Tp = coldef.ConstrPrimaryKey + case ast.ColumnOptionNull: + oldColumnOpt.Tp = coldef.ConstrNull + case ast.ColumnOptionUniq: + oldColumnOpt.Tp = coldef.ConstrUniq + case ast.ColumnOptionUniqIndex: + oldColumnOpt.Tp = coldef.ConstrUniqIndex + case ast.ColumnOptionUniqKey: + oldColumnOpt.Tp = coldef.ConstrUniqKey + } + if v.Expr != nil { + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldColumnOpt.Evalue = oldExpr + } + return oldColumnOpt, nil +} + +func convertColumnDef(converter *expressionConverter, v *ast.ColumnDef) (*coldef.ColumnDef, error) { + oldColDef := &coldef.ColumnDef{ + Name: v.Name.Name.O, + Tp: v.Tp, + } + for _, val := range v.Options { + oldOpt, err := convertColumnOption(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldColDef.Constraints = append(oldColDef.Constraints, oldOpt) + } + return oldColDef, nil +} + +func convertIndexColNames(v []*ast.IndexColName) (out []*coldef.IndexColName) { + for _, val := range v { + oldIndexColKey := &coldef.IndexColName{ + ColumnName: val.Column.Name.O, + Length: val.Length, + } + out = append(out, oldIndexColKey) + } + return +} + +func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*coldef.TableConstraint, error) { + oldConstraint := &coldef.TableConstraint{ConstrName: v.Name} + switch v.Tp { + case ast.ConstraintNoConstraint: + oldConstraint.Tp = coldef.ConstrNoConstr + case ast.ConstraintPrimaryKey: + oldConstraint.Tp = coldef.ConstrPrimaryKey + case ast.ConstraintKey: + oldConstraint.Tp = coldef.ConstrKey + case ast.ConstraintIndex: + oldConstraint.Tp = coldef.ConstrIndex + case ast.ConstraintUniq: + oldConstraint.Tp = coldef.ConstrUniq + case ast.ConstraintUniqKey: + oldConstraint.Tp = coldef.ConstrUniqKey + case ast.ConstraintUniqIndex: + oldConstraint.Tp = coldef.ConstrUniqIndex + case ast.ConstraintForeignKey: + oldConstraint.Tp = coldef.ConstrForeignKey + case ast.ConstraintFulltext: + oldConstraint.Tp = coldef.ConstrFulltext + } + oldConstraint.Keys = convertIndexColNames(v.Keys) + if v.Refer != nil { + oldConstraint.Refer = &coldef.ReferenceDef{ + TableIdent: table.Ident{Schema: v.Refer.Table.Schema, Name: v.Refer.Table.Name}, + IndexColNames: convertIndexColNames(v.Refer.IndexColNames), + } + } + return oldConstraint, nil +} + func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { oldCreateTable := &stmts.CreateTableStmt{ - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + converter := newExpressionConverter() + for _, val := range v.Cols { + oldColDef, err := convertColumnDef(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Cols = append(oldCreateTable.Cols, oldColDef) + } + for _, val := range v.Constraints { + oldConstr, err := convertConstraint(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Constraints = append(oldCreateTable.Constraints, oldConstr) + } + if len(v.Options) != 0 { + oldTableOpt := &coldef.TableOption{} + for _, val := range v.Options { + switch val.Tp { + case ast.TableOptionEngine: + oldTableOpt.Engine = val.StrValue + case ast.TableOptionCharset: + oldTableOpt.Charset = val.StrValue + case ast.TableOptionCollate: + oldTableOpt.Collate = val.StrValue + case ast.TableOptionAutoIncrement: + oldTableOpt.AutoIncrement = val.UintValue + } + } + oldCreateTable.Opt = oldTableOpt } return oldCreateTable, nil } @@ -317,7 +579,7 @@ func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) oldCreateIndex.IndexColNames = make([]*coldef.IndexColName, len(v.IndexColNames)) for i, val := range v.IndexColNames { oldIndexColName := &coldef.IndexColName{ - ColumnName: concatCIStr(val.Column.Schema, val.Column.Table, val.Column.Name), + ColumnName: joinColumnName(val.Column), Length: val.Length, } oldCreateIndex.IndexColNames[i] = oldIndexColName @@ -430,7 +692,7 @@ func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { } } if v.Column != nil { - oldShow.ColumnName = concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + oldShow.ColumnName = joinColumnName(v.Column) } if v.Where != nil { converter := newExpressionConverter() From 7c7473969f99aa21143bd4a674a0a69de480a12c Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 29 Oct 2015 20:29:02 +0800 Subject: [PATCH 05/25] tidb: switch to use ast parser. --- ast/cloner.go | 7 +- ast/dml.go | 8 ++ ast/expressions.go | 5 +- ast/parser/parser.y | 101 +++++++++--------- ast/parser/scanner.l | 13 ++- optimizer/convert_expr.go | 4 +- optimizer/convert_stmt.go | 208 ++++++++++++++++++++++++++++---------- optimizer/optimizer.go | 80 +++++++++------ tidb.go | 23 +++-- tidb_test.go | 1 + 10 files changed, 297 insertions(+), 153 deletions(-) diff --git a/ast/cloner.go b/ast/cloner.go index f08ef09476..bf7b3c961d 100644 --- a/ast/cloner.go +++ b/ast/cloner.go @@ -13,6 +13,8 @@ package ast +import "fmt" + // Cloner is a ast visitor that clones a node. type Cloner struct { } @@ -57,6 +59,9 @@ func cloneStruct(in Node) (out Node) { case *ColumnName: nv := *v out = &nv + case *ColumnNameExpr: + nv := *v + out = &nv case *DefaultExpr: nv := *v out = &nv @@ -162,7 +167,7 @@ func cloneStruct(in Node) (out Node) { default: // We currently only handle expression and select statement. // Will add more when we need to. - panic("unknown ast Node type") + panic("unknown ast Node type " + fmt.Sprintf("%T", v)) } return } diff --git a/ast/dml.go b/ast/dml.go index 0225af0462..b19e60bfc6 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -488,6 +488,7 @@ type UnionStmt struct { Distinct bool Selects []*SelectStmt + OrderBy *OrderByClause Limit *Limit } @@ -505,6 +506,13 @@ func (nod *UnionStmt) Accept(v Visitor) (Node, bool) { } nod.Selects[i] = node.(*SelectStmt) } + if nod.OrderBy != nil { + node, ok := nod.OrderBy.Accept(v) + if !ok { + return nod, false + } + nod.OrderBy = node.(*OrderByClause) + } if nod.Limit != nil { node, ok := nod.Limit.Accept(v) if !ok { diff --git a/ast/expressions.go b/ast/expressions.go index 43fc176498..e32b14752a 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -14,6 +14,7 @@ package ast import ( + "fmt" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" @@ -58,6 +59,8 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Data = value // TODO: make it more precise. switch value.(type) { + case nil: + ve.Type = types.NewFieldType(mysql.TypeNull) case bool, int64: ve.Type = types.NewFieldType(mysql.TypeLonglong) case uint64: @@ -80,7 +83,7 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Type.Charset = "binary" ve.Type.Collate = "binary" default: - panic("illegal literal value type") + panic("illegal literal value type:" + fmt.Sprintf("T", value)) } return ve } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 6f096b8687..643195111a 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -773,13 +773,9 @@ ColumnNameListOpt: { $$ = []*ast.ColumnName{} } -| '(' ')' +| ColumnNameList { - $$ = []*ast.ColumnName{} - } -| '(' ColumnNameList ')' - { - $$ = $2.([]*ast.ColumnName) + $$ = $1.([]*ast.ColumnName) } CommitStmt: @@ -1144,12 +1140,13 @@ DeleteFromStmt: LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - Order: $8.([]*ast.ByItem), } if $7 != nil { x.Where = $7.(ast.ExprNode) } - + if $8 != nil { + x.Order = $8.([]*ast.ByItem) + } if $9 != nil { x.Limit = $9.(*ast.Limit) } @@ -1501,13 +1498,13 @@ Field: } | Expression FieldAsNameOpt { - $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: $2.(model.CIStr)} + $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: model.NewCIStr($2.(string))} } FieldAsNameOpt: /* EMPTY */ { - $$ = model.CIStr{} + $$ = "" } | FieldAsName { @@ -2213,7 +2210,7 @@ TrimDirection: FunctionCallAgg: "AVG" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt ExpressionList ')' { @@ -2550,14 +2547,32 @@ RollbackStmt: } SelectStmt: - "SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt + "SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt { - $$ = &ast.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), Fields: $3.(*ast.FieldList), - From: nil, - LockTp: $6.(ast.SelectLockType), + LockTp: $5.(ast.SelectLockType), } + if $4 != nil { + st.Limit = $4.(*ast.Limit) + } + $$ = st + } +| "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt + { + st := &ast.SelectStmt { + Distinct: $2.(bool), + Fields: $3.(*ast.FieldList), + LockTp: $7.(ast.SelectLockType), + } + if $5 != nil { + st.Where = $5.(ast.ExprNode) + } + if $6 != nil { + st.Limit = $6.(*ast.Limit) + } + $$ = st } | "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional @@ -2594,8 +2609,7 @@ SelectStmt: } FromDual: - /* Empty */ -| "FROM" "DUAL" + "FROM" "DUAL" TableRefsClause: @@ -2836,13 +2850,7 @@ UnionStmt: { union := $1.(*ast.UnionStmt) if $2 != nil { - // push union order by into select statements. - orderBy := $2.(*ast.OrderByClause) - cloner := &ast.Cloner{} - for _, s := range union.Selects { - node, _ := orderBy.Accept(cloner) - s.OrderBy = node.(*ast.OrderByClause) - } + union.OrderBy = $2.(*ast.OrderByClause) } if $3 != nil { union.Limit = $3.(*ast.Limit) @@ -3052,12 +3060,15 @@ ShowStmt: } | "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowTables, DBName: $4.(string), Full: $2.(bool), - Where: $5.(ast.ExprNode), } + if $5 != nil { + stmt.Where = $5.(ast.ExprNode) + } + $$ = stmt } | "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt { @@ -3074,18 +3085,24 @@ ShowStmt: } | "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowVariables, GlobalScope: $2.(bool), - Where: $4.(ast.ExprNode), } + if $4 != nil { + stmt.Where = $4.(ast.ExprNode) + } + $$ = stmt } | "SHOW" "COLLATION" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowCollation, - Where: $3.(ast.ExprNode), } + if $3 != nil { + stmt.Where = $3.(ast.ExprNode) + } + $$ = stmt } | "SHOW" "CREATE" "TABLE" TableName { @@ -3174,6 +3191,7 @@ Statement: | PreparedStmt | RollbackStmt | SelectStmt +| UnionStmt | SetStmt | ShowStmt | TruncateTableStmt @@ -3807,13 +3825,11 @@ StringName: * See: https://dev.mysql.com/doc/refman/5.7/en/update.html ***********************************************************************************/ UpdateStmt: - "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause + "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { - // Single-table syntax - join := &ast.Join{Left: &ast.TableSource{Source:$4.(ast.ResultSetNode)}, Right: nil} st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: join}, + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, List: $6.([]*ast.Assignment), } if $7 != nil { @@ -3830,23 +3846,6 @@ UpdateStmt: break } } -| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional - { - // Multiple-table syntax - st := &ast.UpdateStmt{ - LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, - List: $6.([]*ast.Assignment), - MultipleTable: true, - } - if $7 != nil { - st.Where = $7.(ast.ExprNode) - } - $$ = st - if yylex.(*lexer).root { - break - } - } UseStmt: "USE" DBName diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index a23a68a72a..cdbe765347 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/stringutil" ) type lexer struct { @@ -46,6 +47,7 @@ type lexer struct { val []byte ungetBuf []byte root bool + prepare bool stmtStartPos int stringLit []byte @@ -86,6 +88,14 @@ func (l *lexer) SetInj(inj int) { l.inj = inj } +func (l *lexer) SetPrepare() { + l.prepare = true +} + +func (l *lexer) IsPrepare() bool { + return l.prepare +} + func (l *lexer) Root() bool { return l.root } @@ -1001,7 +1011,8 @@ func (l *lexer) str(lval *yySymType, pref string) int { s = strings.TrimSuffix(s, "'") + "\"" pref = "\"" } - v, err := strconv.Unquote(pref + s) + v := stringutil.RemoveUselessBackslash(pref+s) + v, err := strconv.Unquote(v) if err != nil { v = strings.TrimSuffix(s, pref) } diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 41c78a815e..7a62c1ee96 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -166,14 +166,14 @@ func (c *expressionConverter) subquery(v *ast.SubqueryExpr) { oldSubquery := &subquery.SubQuery{} switch x := v.Query.(type) { case *ast.SelectStmt: - oldSelect, err := convertSelect(x) + oldSelect, err := convertSelect(c, x) if err != nil { c.err = err return } oldSubquery.Stmt = oldSelect case *ast.UnionStmt: - oldUnion, err := convertUnion(x) + oldUnion, err := convertUnion(c, x) if err != nil { c.err = err return diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 036bfebefd..b58f2d240d 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -16,8 +16,10 @@ package optimizer import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/stmt" @@ -37,17 +39,16 @@ func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expr return oldAssign, nil } -func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { +func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { oldInsert := &stmts.InsertIntoStmt{ Priority: v.Priority, Text: v.Text(), } - tableName := v.Table.TableRefs.Left.(*ast.TableName) + tableName := v.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} for _, val := range v.Columns { oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) } - converter := newExpressionConverter() for _, row := range v.Lists { var oldRow []expression.Expression for _, val := range row { @@ -77,9 +78,9 @@ func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { var err error switch x := v.Select.(type) { case *ast.SelectStmt: - oldInsert.Sel, err = convertSelect(x) + oldInsert.Sel, err = convertSelect(converter, x) case *ast.UnionStmt: - oldInsert.Sel, err = convertUnion(x) + oldInsert.Sel, err = convertUnion(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -88,7 +89,7 @@ func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { return oldInsert, nil } -func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { +func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { oldDelete := &stmts.DeleteStmt{ BeforeFrom: v.BeforeFrom, Ignore: v.Ignore, @@ -97,7 +98,6 @@ func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { Quick: v.Quick, Text: v.Text(), } - converter := newExpressionConverter() oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) if err != nil { return nil, errors.Trace(err) @@ -131,14 +131,13 @@ func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { return oldDelete, nil } -func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { +func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { oldUpdate := &stmts.UpdateStmt{ Ignore: v.Ignore, MultipleTable: v.MultipleTable, LowPriority: v.LowPriority, Text: v.Text(), } - converter := newExpressionConverter() var err error oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) if err != nil { @@ -176,10 +175,7 @@ func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { return oldUpdate, nil } -func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { - converter := &expressionConverter{ - exprMap: map[ast.Node]expression.Expression{}, - } +func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { oldSelect := &stmts.SelectStmt{} oldSelect.Distinct = s.Distinct oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) @@ -187,9 +183,13 @@ func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { oldField := &field.Field{} oldField.AsName = val.AsName.O var err error - oldField.Expr, err = convertExpr(converter, val.Expr) - if err != nil { - return nil, errors.Trace(err) + if val.Expr != nil { + oldField.Expr, err = convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + } else if val.WildCard != nil { + oldField.Expr = &expression.Ident{CIStr: model.NewCIStr("*")} } oldSelect.Fields[i] = oldField } @@ -245,7 +245,7 @@ func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { return oldSelect, nil } -func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { +func convertUnion(converter *expressionConverter, u *ast.UnionStmt) (*stmts.UnionStmt, error) { oldUnion := &stmts.UnionStmt{} oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) oldUnion.Distincts = make([]bool, len(u.Selects)-1) @@ -255,12 +255,19 @@ func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { } } for i, val := range u.Selects { - oldSelect, err := convertSelect(val) + oldSelect, err := convertSelect(converter, val) if err != nil { return nil, errors.Trace(err) } oldUnion.Selects[i] = oldSelect } + if u.OrderBy != nil { + oldOrderBy, err := convertOrderBy(converter, u.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.OrderBy = oldOrderBy + } if u.Limit != nil { if u.Limit.Offset > 0 { oldUnion.Offset = &rsets.OffsetRset{Count: u.Limit.Offset} @@ -298,7 +305,7 @@ func convertJoin(converter *expressionConverter, join *ast.Join) (*rsets.JoinRse oldJoin.Left = oldLeft } - switch r := join.Left.(type) { + switch r := join.Right.(type) { case *ast.Join: oldRight, err := convertJoin(converter, r) if err != nil { @@ -331,13 +338,13 @@ func convertTableSource(converter *expressionConverter, ts *ast.TableSource) (*r case *ast.TableName: oldTs.Source = table.Ident{Schema: src.Schema, Name: src.Name} case *ast.SelectStmt: - oldSelect, err := convertSelect(src) + oldSelect, err := convertSelect(converter, src) if err != nil { return nil, errors.Trace(err) } oldTs.Source = oldSelect case *ast.UnionStmt: - oldUnion, err := convertUnion(src) + oldUnion, err := convertUnion(converter, src) if err != nil { return nil, errors.Trace(err) } @@ -384,7 +391,7 @@ func convertOrderBy(converter *expressionConverter, orderBy *ast.OrderByClause) return oldOrderBy, nil } -func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { +func convertCreateDatabase(converter *expressionConverter, v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { oldCreateDatabase := &stmts.CreateDatabaseStmt{ IfNotExists: v.IfNotExists, Name: v.Name, @@ -404,7 +411,7 @@ func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt return oldCreateDatabase, nil } -func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { +func convertDropDatabase(converter *expressionConverter, v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { return &stmts.DropDatabaseStmt{ IfExists: v.IfExists, Name: v.Name, @@ -512,12 +519,12 @@ func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*cold return oldConstraint, nil } -func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { +func convertCreateTable(converter *expressionConverter, v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { oldCreateTable := &stmts.CreateTableStmt{ - Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + IfNotExists: v.IfNotExists, + Text: v.Text(), } - converter := newExpressionConverter() for _, val := range v.Cols { oldColDef, err := convertColumnDef(converter, val) if err != nil { @@ -551,7 +558,7 @@ func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) return oldCreateTable, nil } -func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { +func convertDropTable(converter *expressionConverter, v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { oldDropTable := &stmts.DropTableStmt{ IfExists: v.IfExists, Text: v.Text(), @@ -566,7 +573,7 @@ func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { return oldDropTable, nil } -func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { +func convertCreateIndex(converter *expressionConverter, v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { oldCreateIndex := &stmts.CreateIndexStmt{ IndexName: v.IndexName, Unique: v.Unique, @@ -587,7 +594,7 @@ func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) return oldCreateIndex, nil } -func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { +func convertDropIndex(converter *expressionConverter, v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { return &stmts.DropIndexStmt{ IfExists: v.IfExists, IndexName: v.IndexName, @@ -595,14 +602,110 @@ func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { }, nil } -func convertAlterTable(v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { +func convertAlterTableSpec(converter *expressionConverter, v *ast.AlterTableSpec) (*ddl.AlterSpecification, error) { + oldAlterSpec := &ddl.AlterSpecification{ + Name: v.Name, + } + switch v.Tp { + case ast.AlterTableAddConstraint: + oldAlterSpec.Action = ddl.AlterAddConstr + case ast.AlterTableAddColumn: + oldAlterSpec.Action = ddl.AlterAddColumn + case ast.AlterTableDropColumn: + oldAlterSpec.Action = ddl.AlterDropColumn + case ast.AlterTableDropForeignKey: + oldAlterSpec.Action = ddl.AlterDropForeignKey + case ast.AlterTableDropIndex: + oldAlterSpec.Action = ddl.AlterDropIndex + case ast.AlterTableDropPrimaryKey: + oldAlterSpec.Action = ddl.AlterDropPrimaryKey + case ast.AlterTableOption: + oldAlterSpec.Action = ddl.AlterTableOpt + } + if v.Column != nil { + oldColDef, err := convertColumnDef(converter, v.Column) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Column = oldColDef + } + if v.Position != nil { + oldAlterSpec.Position = &ddl.ColumnPosition{} + switch v.Position.Tp { + case ast.ColumnPositionNone: + oldAlterSpec.Position.Type = ddl.ColumnPositionNone + case ast.ColumnPositionFirst: + oldAlterSpec.Position.Type = ddl.ColumnPositionFirst + case ast.ColumnPositionAfter: + oldAlterSpec.Position.Type = ddl.ColumnPositionAfter + } + if v.ColumnName != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.ColumnName) + } + } + if v.Constraint != nil { + oldConstraint, err := convertConstraint(converter, v.Constraint) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Constraint = oldConstraint + } + for _, val := range v.Options { + oldOpt := &coldef.TableOpt{ + StrValue: val.StrValue, + UintValue: val.UintValue, + } + switch val.Tp { + case ast.TableOptionNone: + oldOpt.Tp = coldef.TblOptNone + case ast.TableOptionEngine: + oldOpt.Tp = coldef.TblOptEngine + case ast.TableOptionCharset: + oldOpt.Tp = coldef.TblOptCharset + case ast.TableOptionCollate: + oldOpt.Tp = coldef.TblOptCollate + case ast.TableOptionAutoIncrement: + oldOpt.Tp = coldef.TblOptAutoIncrement + case ast.TableOptionComment: + oldOpt.Tp = coldef.TblOptComment + case ast.TableOptionAvgRowLength: + oldOpt.Tp = coldef.TblOptAvgRowLength + case ast.TableOptionCheckSum: + oldOpt.Tp = coldef.TblOptCheckSum + case ast.TableOptionCompression: + oldOpt.Tp = coldef.TblOptCompression + case ast.TableOptionConnection: + oldOpt.Tp = coldef.TblOptConnection + case ast.TableOptionPassword: + oldOpt.Tp = coldef.TblOptPassword + case ast.TableOptionKeyBlockSize: + oldOpt.Tp = coldef.TblOptKeyBlockSize + case ast.TableOptionMaxRows: + oldOpt.Tp = coldef.TblOptMaxRows + case ast.TableOptionMinRows: + oldOpt.Tp = coldef.TblOptMinRows + } + oldAlterSpec.TableOpts = append(oldAlterSpec.TableOpts, oldOpt) + } + return oldAlterSpec, nil +} + +func convertAlterTable(converter *expressionConverter, v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { oldAlterTable := &stmts.AlterTableStmt{ - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertAlterTableSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterTable.Specs = append(oldAlterTable.Specs, oldSpec) } return oldAlterTable, nil } -func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { +func convertTruncateTable(converter *expressionConverter, v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { return &stmts.TruncateTableStmt{ TableIdent: table.Ident{ Schema: v.Table.Schema, @@ -612,20 +715,20 @@ func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, e }, nil } -func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { +func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { oldExplain := &stmts.ExplainStmt{ Text: v.Text(), } var err error switch x := v.Stmt.(type) { case *ast.SelectStmt: - oldExplain.S, err = convertSelect(x) + oldExplain.S, err = convertSelect(converter, x) case *ast.UpdateStmt: - oldExplain.S, err = convertUpdate(x) + oldExplain.S, err = convertUpdate(converter, x) case *ast.DeleteStmt: - oldExplain.S, err = convertDelete(x) + oldExplain.S, err = convertDelete(converter, x) case *ast.InsertStmt: - oldExplain.S, err = convertInsert(x) + oldExplain.S, err = convertInsert(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -633,7 +736,7 @@ func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { return oldExplain, nil } -func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { +func convertPrepare(converter *expressionConverter, v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { oldPrepare := &stmts.PreparedStmt{ InPrepare: true, Name: v.Name, @@ -651,7 +754,7 @@ func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { return oldPrepare, nil } -func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { +func convertDeallocate(converter *expressionConverter, v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { return &stmts.DeallocateStmt{ ID: v.ID, Name: v.Name, @@ -659,14 +762,13 @@ func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { }, nil } -func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { +func convertExecute(converter *expressionConverter, v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { oldExec := &stmts.ExecuteStmt{ ID: v.ID, Name: v.Name, Text: v.Text(), } oldExec.UsingVars = make([]expression.Expression, len(v.UsingVars)) - converter := newExpressionConverter() for i, val := range v.UsingVars { oldVar, err := convertExpr(converter, val) if err != nil { @@ -677,7 +779,7 @@ func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { return oldExec, nil } -func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { +func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowStmt, error) { oldShow := &stmts.ShowStmt{ DBName: v.DBName, Flag: v.Flag, @@ -735,25 +837,25 @@ func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { return oldShow, nil } -func convertBegin(v *ast.BeginStmt) (*stmts.BeginStmt, error) { +func convertBegin(converter *expressionConverter, v *ast.BeginStmt) (*stmts.BeginStmt, error) { return &stmts.BeginStmt{ Text: v.Text(), }, nil } -func convertCommit(v *ast.CommitStmt) (*stmts.CommitStmt, error) { +func convertCommit(converter *expressionConverter, v *ast.CommitStmt) (*stmts.CommitStmt, error) { return &stmts.CommitStmt{ Text: v.Text(), }, nil } -func convertRollback(v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { +func convertRollback(converter *expressionConverter, v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { return &stmts.RollbackStmt{ Text: v.Text(), }, nil } -func convertUse(v *ast.UseStmt) (*stmts.UseStmt, error) { +func convertUse(converter *expressionConverter, v *ast.UseStmt) (*stmts.UseStmt, error) { return &stmts.UseStmt{ DBName: v.DBName, Text: v.Text(), @@ -775,12 +877,11 @@ func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAs }, nil } -func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { +func convertSet(converter *expressionConverter, v *ast.SetStmt) (*stmts.SetStmt, error) { oldSet := &stmts.SetStmt{ Text: v.Text(), Variables: make([]*stmts.VariableAssignment, len(v.Variables)), } - converter := newExpressionConverter() for i, val := range v.Variables { oldAssign, err := convertVariableAssignment(converter, val) if err != nil { @@ -791,7 +892,7 @@ func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { return oldSet, nil } -func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { +func convertSetCharset(converter *expressionConverter, v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { return &stmts.SetCharsetStmt{ Charset: v.Charset, Collate: v.Collate, @@ -799,7 +900,7 @@ func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { }, nil } -func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { +func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { return &stmts.SetPwdStmt{ User: v.User, Password: v.Password, @@ -807,14 +908,13 @@ func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { }, nil } -func convertDo(v *ast.DoStmt) (*stmts.DoStmt, error) { - exprConverter := newExpressionConverter() +func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, error) { oldDo := &stmts.DoStmt{ Text: v.Text(), Exprs: make([]expression.Expression, len(v.Exprs)), } for i, val := range v.Exprs { - oldExpr, err := convertExpr(exprConverter, val) + oldExpr, err := convertExpr(converter, val) if err != nil { return nil, errors.Trace(err) } diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index cec754d164..c4d3966cd8 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -16,79 +16,95 @@ package optimizer import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/stmt" ) +type Compiler struct { + paramMarkers []*expression.ParamMarker +} + // Compile compiles a ast.Node into a executable statement. -func Compile(node ast.Node) (stmt.Statement, error) { +func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { validator := &validator{} if _, ok := node.Accept(validator); !ok { return nil, errors.Trace(validator.err) } - binder := &InfoBinder{} - if _, ok := node.Accept(validator); !ok { - return nil, errors.Trace(binder.Err) - } + // binder := &InfoBinder{} + // if _, ok := node.Accept(validator); !ok { + // return nil, errors.Trace(binder.Err) + // } tpComputer := &typeComputer{} if _, ok := node.Accept(tpComputer); !ok { return nil, errors.Trace(tpComputer.err) } - + c := newExpressionConverter() + defer func() { + for _, v := range c.exprMap { + if x, ok := v.(*expression.ParamMarker); ok { + com.paramMarkers = append(com.paramMarkers, x) + } + } + }() switch v := node.(type) { case *ast.InsertStmt: - return convertInsert(v) + return convertInsert(c, v) case *ast.DeleteStmt: - return convertDelete(v) + return convertDelete(c, v) case *ast.UpdateStmt: - return convertUpdate(v) + return convertUpdate(c, v) case *ast.SelectStmt: - return convertSelect(v) + return convertSelect(c, v) case *ast.UnionStmt: - return convertUnion(v) + return convertUnion(c, v) case *ast.CreateDatabaseStmt: - return convertCreateDatabase(v) + return convertCreateDatabase(c, v) case *ast.DropDatabaseStmt: - return convertDropDatabase(v) + return convertDropDatabase(c, v) case *ast.CreateTableStmt: - return convertCreateTable(v) + return convertCreateTable(c, v) case *ast.DropTableStmt: - return convertDropTable(v) + return convertDropTable(c, v) case *ast.CreateIndexStmt: - return convertCreateIndex(v) + return convertCreateIndex(c, v) case *ast.DropIndexStmt: - return convertDropIndex(v) + return convertDropIndex(c, v) case *ast.AlterTableStmt: - return convertAlterTable(v) + return convertAlterTable(c, v) case *ast.TruncateTableStmt: - return convertTruncateTable(v) + return convertTruncateTable(c, v) case *ast.ExplainStmt: - return convertExplain(v) + return convertExplain(c, v) case *ast.PrepareStmt: - return convertPrepare(v) + return convertPrepare(c, v) case *ast.DeallocateStmt: - return convertDeallocate(v) + return convertDeallocate(c, v) case *ast.ExecuteStmt: - return convertExecute(v) + return convertExecute(c, v) case *ast.ShowStmt: - return convertShow(v) + return convertShow(c, v) case *ast.BeginStmt: - return convertBegin(v) + return convertBegin(c, v) case *ast.CommitStmt: - return convertCommit(v) + return convertCommit(c, v) case *ast.RollbackStmt: - return convertRollback(v) + return convertRollback(c, v) case *ast.UseStmt: - return convertUse(v) + return convertUse(c, v) case *ast.SetStmt: - return convertSet(v) + return convertSet(c, v) case *ast.SetCharsetStmt: - return convertSetCharset(v) + return convertSetCharset(c, v) case *ast.SetPwdStmt: - return convertSetPwd(v) + return convertSetPwd(c, v) case *ast.DoStmt: - return convertDo(v) + return convertDo(c, v) } return nil, nil } + +func (com *Compiler) ParamMarkers() []*expression.ParamMarker { + return com.paramMarkers +} diff --git a/tidb.go b/tidb.go index 55581af782..27fcf5167a 100644 --- a/tidb.go +++ b/tidb.go @@ -26,14 +26,13 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/optimizer" - "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/variable" @@ -132,15 +131,12 @@ func Compile(ctx context.Context, src string) ([]stmt.Statement, error) { rawStmt := l.Stmts() stmts := make([]stmt.Statement, len(rawStmt)) for i, v := range rawStmt { - if node, ok := v.(ast.Node); ok { - stm, err := optimizer.Compile(node) - if err != nil { - return nil, errors.Trace(err) - } - stmts[i] = stm - } else { - stmts[i] = v.(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, err := compiler.Compile(v) + if err != nil { + return nil, errors.Trace(err) } + stmts[i] = stm } return stmts, nil } @@ -162,7 +158,12 @@ func CompilePrepare(ctx context.Context, src string) (stmt.Statement, []*express return nil, nil, nil } sm := sms[0] - return sm.(stmt.Statement), l.ParamList, nil + compiler := &optimizer.Compiler{} + statement, err := compiler.Compile(sm) + if err != nil { + return nil, nil, errors.Trace(err) + } + return statement, compiler.ParamMarkers(), nil } func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) { diff --git a/tidb_test.go b/tidb_test.go index 0cb22ed701..8249ba21aa 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -640,6 +640,7 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) { // conflict mustExecSQL(c, se1, "begin") rs, err := exec(c, se1, "select * from t where c1=11 for update") + c.Assert(err, IsNil) _, err = rs.Rows(-1, 0) mustExecSQL(c, se2, "begin") From 8f534190160104c093ed891435ad7eea2688d999 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Fri, 30 Oct 2015 14:39:48 +0800 Subject: [PATCH 06/25] optimizer: fix bugs, pass tests. --- ast/ddl.go | 8 +- ast/expressions.go | 22 ++- ast/misc.go | 100 ++++++++++- ast/parser/parser.y | 230 ++++++++++++++++++++---- ast/parser/scanner.l | 6 + ddl/ddl_test.go | 7 +- optimizer/{optimizer.go => compiler.go} | 8 +- optimizer/convert_stmt.go | 103 ++++++++++- stmt/stmts/drop_test.go | 2 +- stmt/stmts/show_test.go | 2 +- 10 files changed, 426 insertions(+), 62 deletions(-) rename optimizer/{optimizer.go => compiler.go} (90%) diff --git a/ast/ddl.go b/ast/ddl.go index 37a5d99c88..87a5ffba70 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -498,7 +498,7 @@ type AlterTableSpec struct { Constraint *Constraint Options []*TableOption Column *ColumnDef - ColumnName *ColumnName + DropColumn *ColumnName Position *ColumnPosition } @@ -523,12 +523,12 @@ func (nod *AlterTableSpec) Accept(v Visitor) (Node, bool) { } nod.Column = node.(*ColumnDef) } - if nod.ColumnName != nil { - node, ok := nod.ColumnName.Accept(v) + if nod.DropColumn != nil { + node, ok := nod.DropColumn.Accept(v) if !ok { return nod, false } - nod.ColumnName = node.(*ColumnName) + nod.DropColumn = node.(*ColumnName) } if nod.Position != nil { node, ok := nod.Position.Accept(v) diff --git a/ast/expressions.go b/ast/expressions.go index e32b14752a..6c09874750 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -83,7 +83,7 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Type.Charset = "binary" ve.Type.Collate = "binary" default: - panic("illegal literal value type:" + fmt.Sprintf("T", value)) + panic(fmt.Sprintf("illegal literal value type:%T", value)) } return ve } @@ -581,16 +581,20 @@ func (nod *PatternLikeExpr) Accept(v Visitor) (Node, bool) { return v.Leave(newNod) } nod = newNod.(*PatternLikeExpr) - node, ok := nod.Expr.Accept(v) - if !ok { - return nod, false + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) + if !ok { + return nod, false + } + nod.Expr = node.(ExprNode) } - nod.Expr = node.(ExprNode) - node, ok = nod.Pattern.Accept(v) - if !ok { - return nod, false + if nod.Pattern != nil { + node, ok := nod.Pattern.Accept(v) + if !ok { + return nod, false + } + nod.Pattern = node.(ExprNode) } - nod.Pattern = node.(ExprNode) return v.Leave(nod) } diff --git a/ast/misc.go b/ast/misc.go index d269e27c84..05ff97184e 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -13,6 +13,8 @@ package ast +import "github.com/pingcap/tidb/mysql" + var ( _ StmtNode = &ExplainStmt{} _ StmtNode = &PrepareStmt{} @@ -26,7 +28,9 @@ var ( _ StmtNode = &SetStmt{} _ StmtNode = &SetCharsetStmt{} _ StmtNode = &SetPwdStmt{} + _ StmtNode = &CreateUserStmt{} _ StmtNode = &DoStmt{} + _ StmtNode = &GrantStmt{} _ Node = &VariableAssignment{} ) @@ -165,7 +169,7 @@ const ( // ShowStmt is a statement to provide information about databases, tables, columns and so on. // See: https://dev.mysql.com/doc/refman/5.7/en/show.html type ShowStmt struct { - stmtNode + dmlNode Tp ShowStmtType // Databases/Tables/Columns/.... DBName string @@ -379,10 +383,20 @@ type UserSpec struct { // CreateUserStmt creates user account. // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html type CreateUserStmt struct { + stmtNode + IfNotExists bool Specs []*UserSpec +} - Text string +// Accept implements Node Accept interface. +func (nod *CreateUserStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*CreateUserStmt) + return v.Leave(nod) } // DoStmt is the struct for DO statement. @@ -408,3 +422,85 @@ func (nod *DoStmt) Accept(v Visitor) (Node, bool) { } return v.Leave(nod) } + +// PrivElem is the privilege type and optional column list. +type PrivElem struct { + node + Priv mysql.PrivilegeType + Cols []*ColumnName +} + +// Accept implements Node Accept interface. +func (nod *PrivElem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*PrivElem) + for i, val := range nod.Cols { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Cols[i] = node.(*ColumnName) + } + return v.Leave(nod) +} + +// ObjectTypeType is the type for object type. +type ObjectTypeType int + +const ( + // ObjectTypeNone is for empty object type. + ObjectTypeNone ObjectTypeType = iota + // ObjectTypeTable means the following object is a table. + ObjectTypeTable +) + +// GrantLevelType is the type for grant level. +type GrantLevelType int + +const ( + // GrantLevelNone is the dummy const for default value. + GrantLevelNone GrantLevelType = iota + // GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server. + GrantLevelGlobal + // GrantLevelDB means the privileges apply to all objects in a given database. + GrantLevelDB + // GrantLevelTable means the privileges apply to all columns in a given table. + GrantLevelTable +) + +// GrantLevel is used for store the privilege scope. +type GrantLevel struct { + Level GrantLevelType + DBName string + TableName string +} + +// GrantStmt is the struct for GRANT statement. +type GrantStmt struct { + stmtNode + + Privs []*PrivElem + ObjectType ObjectTypeType + Level *GrantLevel + Users []*UserSpec +} + +// Accept implements Node Accept interface. +func (nod *GrantStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*GrantStmt) + for i, val := range nod.Privs { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Privs[i] = node.(*PrivElem) + } + return v.Leave(nod) +} diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 643195111a..97745a62ad 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -141,6 +141,7 @@ import ( fulltext "FULLTEXT" ge ">=" global "GLOBAL" + grant "GRANT" group "GROUP" groupConcat "GROUP_CONCAT" having "HAVING" @@ -190,6 +191,7 @@ import ( nullIf "NULLIF" offset "OFFSET" on "ON" + option "OPTION" or "OR" order "ORDER" oror "||" @@ -229,6 +231,7 @@ import ( tableKwd "TABLE" tables "TABLES" then "THEN" + to "TO" trailing "TRAILING" transaction "TRANSACTION" trim "TRIM" @@ -403,6 +406,7 @@ import ( FunctionNameConflict "Built-in function call names which are conflict with keywords" FuncDatetimePrec "Function datetime precision" GlobalScope "The scope of variable" + GrantStmt "Grant statement" GroupByClause "GROUP BY clause" HashString "Hashed string" HavingClause "HAVING clause" @@ -430,6 +434,7 @@ import ( NotOpt "optional NOT" NowSym "CURRENT_TIMESTAMP/LOCALTIME/LOCALTIMESTAMP/NOW" NumLiteral "Num/Int/Float/Decimal Literal" + ObjectType "Grant statement object type" OnDuplicateKeyUpdate "ON DUPLICATE KEY UPDATE value list" Operand "operand" OptFull "Full or empty" @@ -448,6 +453,10 @@ import ( PrimaryExpression "primary expression" PrimaryFactor "primary expression factor" Priority "insert statement priority" + PrivElem "Privilege element" + PrivElemList "Privilege element list" + PrivLevel "Privilege scope" + PrivType "Privilege type" ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" RollbackStmt "ROLLBACK statement" @@ -491,7 +500,7 @@ import ( UnionOpt "Union Option(empty/ALL/DISTINCT)" UnionStmt "Union select state ment" UnionClauseList "Union select clause list" - UnionClausePList "Union (select) clause list" + UnionSelect "Union (select) item" UpdateStmt "UPDATE statement" Username "Username" UserSpec "Username and auth option" @@ -633,7 +642,7 @@ AlterTableSpec: { $$ = &ast.AlterTableSpec{ Tp: ast.AlterTableDropColumn, - ColumnName: $3.(*ast.ColumnName), + DropColumn: $3.(*ast.ColumnName), } } | "DROP" "PRIMARY" "KEY" @@ -1253,7 +1262,7 @@ ExplainStmt: { $$ = &ast.ExplainStmt{ Stmt: &ast.ShowStmt{ - Tp: ast.ShowTables, + Tp: ast.ShowColumns, Table: $2.(*ast.TableName), }, } @@ -2842,32 +2851,36 @@ SelectLockOpt: // See: https://dev.mysql.com/doc/refman/5.7/en/union.html UnionStmt: - UnionClauseList - { - $$ = $1.(*ast.UnionStmt) - } -| UnionClausePList OrderByOptional SelectStmtLimit + UnionClauseList "UNION" UnionOpt SelectStmt { union := $1.(*ast.UnionStmt) - if $2 != nil { - union.OrderBy = $2.(*ast.OrderByClause) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union + } +| UnionClauseList "UNION" UnionOpt '(' SelectStmt ')' OrderByOptional SelectStmtLimit + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) + if $7 != nil { + union.OrderBy = $7.(*ast.OrderByClause) } - if $3 != nil { - union.Limit = $3.(*ast.Limit) + if $8 != nil { + union.Limit = $8.(*ast.Limit) } $$ = union } UnionClauseList: - SelectStmt "UNION" UnionOpt SelectStmt + UnionSelect { - selects := []*ast.SelectStmt{$1.(*ast.SelectStmt), $4.(*ast.SelectStmt)} + selects := []*ast.SelectStmt{$1.(*ast.SelectStmt)} $$ = &ast.UnionStmt{ - Distinct: $3.(bool), Selects: selects, } } -| UnionClauseList "UNION" UnionOpt SelectStmt +| UnionClauseList "UNION" UnionOpt UnionSelect { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) @@ -2875,21 +2888,11 @@ UnionClauseList: $$ = union } -UnionClausePList: - '(' SelectStmt ')' "UNION" UnionOpt '(' SelectStmt ')' +UnionSelect: + SelectStmt +| '(' SelectStmt ')' { - selects := []*ast.SelectStmt{$2.(*ast.SelectStmt), $7.(*ast.SelectStmt)} - $$ = &ast.UnionStmt{ - Distinct: $5.(bool), - Selects: selects, - } - } -| UnionClausePList "UNION" UnionOpt '(' SelectStmt ')' - { - union := $1.(*ast.UnionStmt) - union.Distinct = union.Distinct || $3.(bool) - union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) - $$ = union + $$ = $2 } UnionOpt: @@ -3066,7 +3069,11 @@ ShowStmt: Full: $2.(bool), } if $5 != nil { - stmt.Where = $5.(ast.ExprNode) + if x, ok := $5.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { + stmt.Where = $5.(ast.ExprNode) + } } $$ = stmt } @@ -3089,7 +3096,9 @@ ShowStmt: Tp: ast.ShowVariables, GlobalScope: $2.(bool), } - if $4 != nil { + if x, ok := $4.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { stmt.Where = $4.(ast.ExprNode) } $$ = stmt @@ -3099,7 +3108,9 @@ ShowStmt: stmt := &ast.ShowStmt{ Tp: ast.ShowCollation, } - if $3 != nil { + if x, ok := $3.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { stmt.Where = $3.(ast.ExprNode) } $$ = stmt @@ -3187,6 +3198,7 @@ Statement: | DropDatabaseStmt | DropIndexStmt | DropTableStmt +| GrantStmt | InsertIntoStmt | PreparedStmt | RollbackStmt @@ -3895,10 +3907,13 @@ CreateUserStmt: UserSpec: Username AuthOption { - $$ = &ast.UserSpec{ + userSpec := &ast.UserSpec{ User: $1.(string), - AuthOpt: $2.(*ast.AuthOption), } + if $2 != nil { + userSpec.AuthOpt = $2.(*ast.AuthOption) + } + $$ = userSpec } UserSpecList: @@ -3912,7 +3927,9 @@ UserSpecList: } AuthOption: - {} + { + $$ = nil + } | "IDENTIFIED" "BY" AuthString { $$ = &ast.AuthOption { @@ -3922,11 +3939,150 @@ AuthOption: } | "IDENTIFIED" "BY" "PASSWORD" HashString { - $$ = &ast.AuthOption { + $$ = &ast.AuthOption{ HashString: $4.(string), } } HashString: stringLit + +/************************************************************************************* + * Grant statement + * See: https://dev.mysql.com/doc/refman/5.7/en/grant.html + *************************************************************************************/ +GrantStmt: + "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecList + { + $$ = &ast.GrantStmt{ + Privs: $2.([]*ast.PrivElem), + ObjectType: $4.(ast.ObjectTypeType), + Level: $5.(*ast.GrantLevel), + Users: $7.([]*ast.UserSpec), + } + } + +PrivElem: + PrivType + { + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + } + } +| PrivType '(' ColumnNameList ')' + { + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + Cols: $3.([]*ast.ColumnName), + } + } + +PrivElemList: + PrivElem + { + $$ = []*ast.PrivElem{$1.(*ast.PrivElem)} + } +| PrivElemList ',' PrivElem + { + $$ = append($1.([]*ast.PrivElem), $3.(*ast.PrivElem)) + } + +PrivType: + "ALL" + { + $$ = mysql.AllPriv + } +| "ALTER" + { + $$ = mysql.AlterPriv + } +| "CREATE" + { + $$ = mysql.CreatePriv + } +| "CREATE" "USER" + { + $$ = mysql.CreateUserPriv + } +| "DELETE" + { + $$ = mysql.DeletePriv + } +| "DROP" + { + $$ = mysql.DropPriv + } +| "EXECUTE" + { + $$ = mysql.ExecutePriv + } +| "INDEX" + { + $$ = mysql.IndexPriv + } +| "INSERT" + { + $$ = mysql.InsertPriv + } +| "SELECT" + { + $$ = mysql.SelectPriv + } +| "SHOW" "DATABASES" + { + $$ = mysql.ShowDBPriv + } +| "UPDATE" + { + $$ = mysql.UpdatePriv + } +| "GRANT" "OPTION" + { + $$ = mysql.GrantPriv + } + +ObjectType: + { + $$ = ast.ObjectTypeNone + } +| "TABLE" + { + $$ = ast.ObjectTypeTable + } + +PrivLevel: + '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + } + } +| '*' '.' '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelGlobal, + } + } +| Identifier '.' '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + DBName: $1.(string), + } + } +| Identifier '.' Identifier + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, + DBName: $1.(string), + TableName: $3.(string), + } + } +| Identifier + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, + TableName: $1.(string), + } + } %% diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index cdbe765347..21a8680588 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -316,6 +316,7 @@ from {f}{r}{o}{m} full {f}{u}{l}{l} fulltext {f}{u}{l}{l}{t}{e}{x}{t} global {g}{l}{o}{b}{a}{l} +grant {g}{r}{a}{n}{t} group {g}{r}{o}{u}{p} group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} having {h}{a}{v}{i}{n}{g} @@ -356,6 +357,7 @@ national {n}{a}{t}{i}{o}{n}{a}{l} not {n}{o}{t} offset {o}{f}{f}{s}{e}{t} on {o}{n} +option {o}{p}{t}{i}{o}{n} or {o}{r} order {o}{r}{d}{e}{r} outer {o}{u}{t}{e}{r} @@ -389,6 +391,7 @@ sysdate {s}{y}{s}{d}{a}{t}{e} table {t}{a}{b}{l}{e} tables {t}{a}{b}{l}{e}{s} then {t}{h}{e}{n} +to {t}{o} trailing {t}{r}{a}{i}{l}{i}{n}{g} transaction {t}{r}{a}{n}{s}{a}{c}{t}{i}{o}{n} trim {t}{r}{i}{m} @@ -683,6 +686,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {full} lval.item = string(l.val) return full {fulltext} return fulltext +{grant} return grant {group} return group {group_concat} lval.item = string(l.val) return groupConcat @@ -757,6 +761,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {offset} lval.item = string(l.val) return offset {on} return on +{option} return option {order} return order {or} return or {outer} return outer @@ -819,6 +824,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {tables} lval.item = string(l.val) return tables {then} return then +{to} return to {trailing} return trailing {transaction} lval.item = string(l.val) return transaction diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 864567a29f..a1805a0bcd 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -19,12 +19,13 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb" + "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/optimizer" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/stmt/stmts" @@ -198,5 +199,7 @@ func statement(sql string) stmt.Statement { log.Debug("Compile", sql) lexer := parser.NewLexer(sql) parser.YYParse(lexer) - return lexer.Stmts()[0].(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, _ := compiler.Compile(lexer.Stmts()[0]) + return stm } diff --git a/optimizer/optimizer.go b/optimizer/compiler.go similarity index 90% rename from optimizer/optimizer.go rename to optimizer/compiler.go index c4d3966cd8..bff4623520 100644 --- a/optimizer/optimizer.go +++ b/optimizer/compiler.go @@ -20,11 +20,12 @@ import ( "github.com/pingcap/tidb/stmt" ) +// Compiler compiles ast.Node into an executable statement. type Compiler struct { paramMarkers []*expression.ParamMarker } -// Compile compiles a ast.Node into a executable statement. +// Compile compiles a ast.Node into an executable statement. func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { validator := &validator{} if _, ok := node.Accept(validator); !ok { @@ -99,12 +100,17 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { return convertSetCharset(c, v) case *ast.SetPwdStmt: return convertSetPwd(c, v) + case *ast.CreateUserStmt: + return convertCreateUser(c, v) case *ast.DoStmt: return convertDo(c, v) + case *ast.GrantStmt: + return convertGrant(c, v) } return nil, nil } +// ParamMarkers returns parameter markers for prepared statement. func (com *Compiler) ParamMarkers() []*expression.ParamMarker { return com.paramMarkers } diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index b58f2d240d..3be497443d 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -176,8 +176,10 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up } func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { - oldSelect := &stmts.SelectStmt{} - oldSelect.Distinct = s.Distinct + oldSelect := &stmts.SelectStmt{ + Distinct: s.Distinct, + Text: s.Text(), + } oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) for i, val := range s.Fields.Fields { oldField := &field.Field{} @@ -246,7 +248,9 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se } func convertUnion(converter *expressionConverter, u *ast.UnionStmt) (*stmts.UnionStmt, error) { - oldUnion := &stmts.UnionStmt{} + oldUnion := &stmts.UnionStmt{ + Text: u.Text(), + } oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) oldUnion.Distincts = make([]bool, len(u.Selects)-1) if u.Distinct { @@ -639,10 +643,13 @@ func convertAlterTableSpec(converter *expressionConverter, v *ast.AlterTableSpec case ast.ColumnPositionAfter: oldAlterSpec.Position.Type = ddl.ColumnPositionAfter } - if v.ColumnName != nil { - oldAlterSpec.Position.RelativeColumn = joinColumnName(v.ColumnName) + if v.Position.RelativeColumn != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.Position.RelativeColumn) } } + if v.DropColumn != nil { + oldAlterSpec.Name = joinColumnName(v.DropColumn) + } if v.Constraint != nil { oldConstraint, err := convertConstraint(converter, v.Constraint) if err != nil { @@ -729,6 +736,8 @@ func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts. oldExplain.S, err = convertDelete(converter, x) case *ast.InsertStmt: oldExplain.S, err = convertInsert(converter, x) + case *ast.ShowStmt: + oldExplain.S, err = convertShow(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -908,6 +917,36 @@ func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.Se }, nil } +func convertUserSpec(converter *expressionConverter, v *ast.UserSpec) (*coldef.UserSpecification, error) { + oldSpec := &coldef.UserSpecification{ + User: v.User, + } + if v.AuthOpt != nil { + oldAuthOpt := &coldef.AuthOption{ + AuthString: v.AuthOpt.AuthString, + ByAuthString: v.AuthOpt.ByAuthString, + HashString: v.AuthOpt.HashString, + } + oldSpec.AuthOpt = oldAuthOpt + } + return oldSpec, nil +} + +func convertCreateUser(converter *expressionConverter, v *ast.CreateUserStmt) (*stmts.CreateUserStmt, error) { + oldCreateUser := &stmts.CreateUserStmt{ + IfNotExists: v.IfNotExists, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateUser.Specs = append(oldCreateUser.Specs, oldSpec) + } + return oldCreateUser, nil +} + func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, error) { oldDo := &stmts.DoStmt{ Text: v.Text(), @@ -922,3 +961,57 @@ func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, er } return oldDo, nil } + +func convertPrivElem(converter *expressionConverter, v *ast.PrivElem) (*coldef.PrivElem, error) { + oldPrim := &coldef.PrivElem{ + Priv: v.Priv, + } + for _, val := range v.Cols { + oldPrim.Cols = append(oldPrim.Cols, joinColumnName(val)) + } + return oldPrim, nil +} + +func convertGrant(converter *expressionConverter, v *ast.GrantStmt) (*stmts.GrantStmt, error) { + oldGrant := &stmts.GrantStmt{ + Text: v.Text(), + } + for _, val := range v.Privs { + oldPrim, err := convertPrivElem(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Privs = append(oldGrant.Privs, oldPrim) + } + switch v.ObjectType { + case ast.ObjectTypeNone: + oldGrant.ObjectType = coldef.ObjectTypeNone + case ast.ObjectTypeTable: + oldGrant.ObjectType = coldef.ObjectTypeTable + } + if v.Level != nil { + oldGrantLevel := &coldef.GrantLevel{ + DBName: v.Level.DBName, + TableName: v.Level.TableName, + } + switch v.Level.Level { + case ast.GrantLevelDB: + oldGrantLevel.Level = coldef.GrantLevelDB + case ast.GrantLevelGlobal: + oldGrantLevel.Level = coldef.GrantLevelGlobal + case ast.GrantLevelNone: + oldGrantLevel.Level = coldef.GrantLevelNone + case ast.GrantLevelTable: + oldGrantLevel.Level = coldef.GrantLevelTable + } + oldGrant.Level = oldGrantLevel + } + for _, val := range v.Users { + oldUserSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Users = append(oldGrant.Users, oldUserSpec) + } + return oldGrant, nil +} diff --git a/stmt/stmts/drop_test.go b/stmt/stmts/drop_test.go index 5d0a34ec0d..4a1de810c2 100644 --- a/stmt/stmts/drop_test.go +++ b/stmt/stmts/drop_test.go @@ -62,7 +62,7 @@ func (s *testStmtSuite) TestDropTable(c *C) { } func (s *testStmtSuite) TestDropIndex(c *C) { - testSQL := "drop index if exists drop_index;" + testSQL := "drop index if exists drop_index on t;" stmtList, err := tidb.Compile(s.ctx, testSQL) c.Assert(err, IsNil) diff --git a/stmt/stmts/show_test.go b/stmt/stmts/show_test.go index 5dc0c57828..e67813a46d 100644 --- a/stmt/stmts/show_test.go +++ b/stmt/stmts/show_test.go @@ -51,5 +51,5 @@ func (s *testStmtSuite) TestShow(c *C) { c.Assert(stmtList, HasLen, 1) testStmt, ok = stmtList[0].(*stmts.ShowStmt) c.Assert(ok, IsTrue) - c.Assert(testStmt.Pattern, NotNil) + c.Assert(testStmt.Pattern, NotNil, Commentf("S: %s", testStmt.Text)) } From 05b7eefa03a0b10315a2a1205c77439270fd0b58 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Fri, 30 Oct 2015 20:51:36 +0800 Subject: [PATCH 07/25] ast: update after rebase. --- ast/dml.go | 1 + ast/functions.go | 10 +++-- ast/misc.go | 2 + ast/parser/parser.y | 80 ++++++++++++++++++++++++++++++++++----- ast/parser/scanner.l | 19 ++++++++++ optimizer/convert_stmt.go | 55 ++++++++++++++++----------- 6 files changed, 130 insertions(+), 37 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index b19e60bfc6..afd836ef66 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -566,6 +566,7 @@ const ( type InsertStmt struct { dmlNode + Replace bool Table *TableRefsClause Columns []*ColumnName Lists [][]ExprNode diff --git a/ast/functions.go b/ast/functions.go index 97b9cd1e3c..78ab444236 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -200,11 +200,13 @@ func (nod *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { return nod, false } nod.Pos = node.(ExprNode) - node, ok = nod.Len.Accept(v) - if !ok { - return nod, false + if nod.Len != nil { + node, ok = nod.Len.Accept(v) + if !ok { + return nod, false + } + nod.Len = node.(ExprNode) } - nod.Len = node.(ExprNode) return v.Leave(nod) } diff --git a/ast/misc.go b/ast/misc.go index 05ff97184e..b0f9e845fc 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -164,6 +164,7 @@ const ( ShowVariables ShowCollation ShowCreateTable + ShowGrants ) // ShowStmt is a statement to provide information about databases, tables, columns and so on. @@ -177,6 +178,7 @@ type ShowStmt struct { Column *ColumnName // Used for `desc table column`. Flag int // Some flag parsed from sql, such as FULL. Full bool + User string // Used for show grants. // Used by show variables GlobalScope bool diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 97745a62ad..9c02a8e4c5 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -142,6 +142,7 @@ import ( ge ">=" global "GLOBAL" grant "GRANT" + grants "GRANTS" group "GROUP" groupConcat "GROUP_CONCAT" having "HAVING" @@ -206,6 +207,7 @@ import ( references "REFERENCES" regexp "REGEXP" repeat "REPEAT" + replace "REPLACE" right "RIGHT" rlike "RLIKE" rollback "ROLLBACK" @@ -418,7 +420,7 @@ import ( IndexName "index name" IndexType "index type" InsertIntoStmt "INSERT INTO statement" - InsertRest "Rest part of INSERT INTO statement" + InsertValues "Rest part of INSERT/REPLACE INTO statement" IntoOpt "INTO or EmptyString" JoinTable "join table" JoinType "join type" @@ -459,6 +461,8 @@ import ( PrivType "Privilege type" ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" + ReplaceIntoStmt "REPLACE INTO statement" + ReplacePriority "replace statement priority" RollbackStmt "ROLLBACK statement" SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," SelectStmt "SELECT statement" @@ -1224,14 +1228,14 @@ DropIndexStmt: } DropTableStmt: - "DROP" "TABLE" TableNameList + "DROP" TableOrTables TableNameList { $$ = &ast.DropTableStmt{Tables: $3.([]*ast.TableName)} if yylex.(*lexer).root { break } } -| "DROP" "TABLE" "IF" "EXISTS" TableNameList +| "DROP" TableOrTables "IF" "EXISTS" TableNameList { $$ = &ast.DropTableStmt{IfExists: true, Tables: $5.([]*ast.TableName)} if yylex.(*lexer).root { @@ -1239,6 +1243,10 @@ DropTableStmt: } } +TableOrTables: + "TABLE" +| "TABLES" + EqOpt: { } @@ -1619,7 +1627,7 @@ UnReservedKeyword: | "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" | "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION" | "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS" -| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" +| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" NotKeywordToken: "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" @@ -1633,7 +1641,7 @@ NotKeywordToken: * TODO: support PARTITION **********************************************************************************/ InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableName InsertRest OnDuplicateKeyUpdate + "INSERT" Priority IgnoreOptional IntoOpt TableName InsertValues OnDuplicateKeyUpdate { x := $6.(*ast.InsertStmt) x.Priority = $2.(int) @@ -1656,7 +1664,7 @@ IntoOpt: { } -InsertRest: +InsertValues: '(' ColumnNameListOpt ')' ValueSym ExpressionListList { $$ = &ast.InsertStmt{ @@ -1741,7 +1749,39 @@ OnDuplicateKeyUpdate: $$ = $5 } -/***********************************Insert Statments END************************************/ +/***********************************Insert Statements END************************************/ + +/************************************************************************************ + * Replace Statements + * See: https://dev.mysql.com/doc/refman/5.7/en/replace.html + * + * TODO: support PARTITION + **********************************************************************************/ +ReplaceIntoStmt: + "REPLACE" ReplacePriority IntoOpt TableName InsertValues + { + x := $5.(*ast.InsertStmt) + x.Replace = true + x.Priority = $2.(int) + ts := &ast.TableSource{Source: $4.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} + $$ = x + } + +ReplacePriority: + { + $$ = ast.NoPriority + } +| "LOW_PRIORITY" + { + $$ = ast.LowPriority + } +| "DELAYED" + { + $$ = ast.DelayedPriority + } + +/***********************************Replace Statments END************************************/ Literal: "false" @@ -1984,7 +2024,7 @@ FunctionCallKeyword: } | "YEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } FunctionCallNonKeyword: @@ -2049,7 +2089,7 @@ FunctionCallNonKeyword: } | "IFNULL" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} } | "LENGTH" '(' Expression ')' { @@ -2107,6 +2147,11 @@ FunctionCallNonKeyword: } $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} } +| "REPLACE" '(' Expression ',' Expression ',' Expression ')' + { + args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + } | "SECOND" '(' Expression ')' { $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} @@ -3118,10 +3163,23 @@ ShowStmt: | "SHOW" "CREATE" "TABLE" TableName { $$ = &ast.ShowStmt{ - Tp: ast.ShowCreateTable, + Tp: ast.ShowCreateTable, Table: $4.(*ast.TableName), } } +| "SHOW" "GRANTS" + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &ast.ShowStmt{Tp: ast.ShowGrants} + } +| "SHOW" "GRANTS" "FOR" Username + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &ast.ShowStmt{ + Tp: ast.ShowGrants, + User: $4.(string), + } + } ShowLikeOrWhereOpt: { @@ -3202,6 +3260,7 @@ Statement: | InsertIntoStmt | PreparedStmt | RollbackStmt +| ReplaceIntoStmt | SelectStmt | UnionStmt | SetStmt @@ -3221,6 +3280,7 @@ ExplainableStmt: | DeleteFromStmt | UpdateStmt | InsertIntoStmt +| ReplaceIntoStmt StatementList: Statement diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index 21a8680588..048817d1b9 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -54,6 +54,10 @@ type lexer struct { // record token's offset of the input tokenEndOffset int tokenStartOffset int + + // Charset information + charset string + collation string } @@ -104,6 +108,15 @@ func (l *lexer) SetRoot(root bool) { l.root = root } +func (l *lexer) SetCharsetInfo(charset, collation string) { + l.charset = charset + l.collation = collation +} + +func (l *lexer) GetCharsetInfo() (string, string) { + return l.charset, l.collation +} + func (l *lexer) unget(b byte) { l.ungetBuf = append(l.ungetBuf, b) l.i-- @@ -317,6 +330,7 @@ full {f}{u}{l}{l} fulltext {f}{u}{l}{l}{t}{e}{x}{t} global {g}{l}{o}{b}{a}{l} grant {g}{r}{a}{n}{t} +grants {g}{r}{a}{n}{t}{s} group {g}{r}{o}{u}{p} group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} having {h}{a}{v}{i}{n}{g} @@ -370,6 +384,7 @@ rand {r}{a}{n}{d} repeat {r}{e}{p}{e}{a}{t} references {r}{e}{f}{e}{r}{e}{n}{c}{e}{s} regexp {r}{e}{g}{e}{x}{p} +replace {r}{e}{p}{l}{a}{c}{e} right {r}{i}{g}{h}{t} rlike {r}{l}{i}{k}{e} rollback {r}{o}{l}{l}{b}{a}{c}{k} @@ -687,6 +702,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} return full {fulltext} return fulltext {grant} return grant +{grants} lval.item = string(l.val) + return grants {group} return group {group_concat} lval.item = string(l.val) return groupConcat @@ -795,6 +812,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {repeat} lval.item = string(l.val) return repeat {regexp} return regexp +{replace} lval.item = string(l.val) + return replace {references} return references {rlike} return rlike diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 3be497443d..ff026f6591 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -39,15 +39,13 @@ func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expr return oldAssign, nil } -func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { - oldInsert := &stmts.InsertIntoStmt{ - Priority: v.Priority, - Text: v.Text(), - } +func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (stmt.Statement, error) { + insertValues := stmts.InsertValues{} + insertValues.Priority = v.Priority tableName := v.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) - oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} + insertValues.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} for _, val := range v.Columns { - oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) + insertValues.ColNames = append(insertValues.ColNames, joinColumnName(val)) } for _, row := range v.Lists { var oldRow []expression.Expression @@ -58,33 +56,44 @@ func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (*stmts.In } oldRow = append(oldRow, oldExpr) } - oldInsert.Lists = append(oldInsert.Lists, oldRow) + insertValues.Lists = append(insertValues.Lists, oldRow) } for _, assign := range v.Setlist { oldAssign, err := convertAssignment(converter, assign) if err != nil { return nil, errors.Trace(err) } - oldInsert.Setlist = append(oldInsert.Setlist, oldAssign) + insertValues.Setlist = append(insertValues.Setlist, oldAssign) + } + + if v.Select != nil { + var err error + switch x := v.Select.(type) { + case *ast.SelectStmt: + insertValues.Sel, err = convertSelect(converter, x) + case *ast.UnionStmt: + insertValues.Sel, err = convertUnion(converter, x) + } + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Replace { + return &stmts.ReplaceIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), + }, nil + } + oldInsert := &stmts.InsertIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), } for _, onDup := range v.OnDuplicate { oldOnDup, err := convertAssignment(converter, onDup) if err != nil { return nil, errors.Trace(err) } - oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, *oldOnDup) - } - if v.Select != nil { - var err error - switch x := v.Select.(type) { - case *ast.SelectStmt: - oldInsert.Sel, err = convertSelect(converter, x) - case *ast.UnionStmt: - oldInsert.Sel, err = convertUnion(converter, x) - } - if err != nil { - return nil, errors.Trace(err) - } + oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, oldOnDup) } return oldInsert, nil } @@ -155,7 +164,7 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up if err != nil { return nil, errors.Trace(err) } - oldUpdate.List = append(oldUpdate.List, *oldAssign) + oldUpdate.List = append(oldUpdate.List, oldAssign) } if v.Order != nil { orderRset := &rsets.OrderByRset{} From 92db5158c489d60608bd057ad6b065f79fc90e68 Mon Sep 17 00:00:00 2001 From: ngaut Date: Sat, 31 Oct 2015 13:18:38 +0800 Subject: [PATCH 08/25] *: Tiny refactor --- ast/cloner.go | 2 +- ast/functions.go | 4 ++-- optimizer/convert_expr.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ast/cloner.go b/ast/cloner.go index bf7b3c961d..be5d7e441e 100644 --- a/ast/cloner.go +++ b/ast/cloner.go @@ -15,7 +15,7 @@ package ast import "fmt" -// Cloner is a ast visitor that clones a node. +// Cloner is an ast visitor that clones a node. type Cloner struct { } diff --git a/ast/functions.go b/ast/functions.go index 78ab444236..56bf663eeb 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -35,7 +35,7 @@ var ( type FuncCallExpr struct { funcNode // F is the function name. - F string + FnName string // Args is the function args. Args []ExprNode } @@ -59,7 +59,7 @@ func (nod *FuncCallExpr) Accept(v Visitor) (Node, bool) { // IsStatic implements the ExprNode IsStatic interface. func (nod *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(nod.F)] + v := builtin.Funcs[strings.ToLower(nod.FnName)] if v.F == nil || !v.IsStatic { return false } diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 7a62c1ee96..ac57e91e77 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -318,7 +318,7 @@ func (c *expressionConverter) variable(v *ast.VariableExpr) { func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) { oldCall := &expression.Call{ - F: v.F, + F: v.FnName, } oldCall.Args = make([]expression.Expression, len(v.Args)) for i, val := range v.Args { From 1b9fbda6b6192de1a4427b1fd0ae6f200f74616d Mon Sep 17 00:00:00 2001 From: ngaut Date: Sat, 31 Oct 2015 13:22:56 +0800 Subject: [PATCH 09/25] ast/parser: Fix parser.y --- ast/parser/parser.y | 70 ++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 9c02a8e4c5..4716d418ff 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1938,16 +1938,16 @@ FunctionNameConflict: FunctionCallConflict: FunctionNameConflict '(' ExpressionListOpt ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURRENT_USER" { // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_DATE" { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } DistinctOpt: @@ -2007,11 +2007,11 @@ FunctionCallKeyword: } | "DATE" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "USER" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues { @@ -2020,21 +2020,21 @@ FunctionCallKeyword: } | "WEEK" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "YEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } FunctionCallNonKeyword: "COALESCE" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURDATE" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_TIMESTAMP" FuncDatetimePrec { @@ -2042,35 +2042,35 @@ FunctionCallNonKeyword: if $2 != nil { args = append(args, $2.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "ABS" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "CONCAT" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CONCAT_WS" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "DAY" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFWEEK" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFMONTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFYEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { @@ -2081,19 +2081,19 @@ FunctionCallNonKeyword: } | "FOUND_ROWS" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "HOUR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "IFNULL" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "LENGTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "LOCATE" '(' Expression ',' Expression ')' { @@ -2112,19 +2112,19 @@ FunctionCallNonKeyword: } | "LOWER" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MICROSECOND" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MINUTE" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MONTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "NOW" '(' ExpressionOpt ')' { @@ -2132,11 +2132,11 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "NULLIF" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "RAND" '(' ExpressionOpt ')' { @@ -2145,16 +2145,16 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "REPLACE" '(' Expression ',' Expression ',' Expression ')' { args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "SECOND" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "SUBSTRING" '(' Expression ',' Expression ')' { @@ -2200,7 +2200,7 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "TRIM" '(' Expression ')' { @@ -2232,19 +2232,19 @@ FunctionCallNonKeyword: } | "UPPER" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKDAY" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKOFYEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "YEARWEEK" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } TrimDirection: From 01dac873baa102c4a8ae9a99b7253d03f9044c90 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Sat, 31 Oct 2015 22:37:27 +0800 Subject: [PATCH 10/25] *: Address ci problem --- ast/expressions.go | 4 +++- ast/parser/parser.y | 18 ++++++++++++++++++ ast/parser/scanner.l | 19 ++++++++++++++++++- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/ast/expressions.go b/ast/expressions.go index 6c09874750..7c82ed2d2a 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -56,7 +56,7 @@ type ValueExpr struct { // NewValueExpr creates a ValueExpr with value, and sets default field type. func NewValueExpr(value interface{}) *ValueExpr { ve := &ValueExpr{} - ve.Data = value + ve.Data = types.RawData(value) // TODO: make it more precise. switch value.(type) { case nil: @@ -82,6 +82,8 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Type = types.NewFieldType(mysql.TypeVarchar) ve.Type.Charset = "binary" ve.Type.Collate = "binary" + case *types.DataItem: + ve.Type = value.(*types.DataItem).Type default: panic(fmt.Sprintf("illegal literal value type:%T", value)) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 9c02a8e4c5..127a3b965d 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -239,6 +239,7 @@ import ( trim "TRIM" trueKwd "true" truncate "TRUNCATE" + underscoreCS "UNDERSCORE_CHARSET" unknown "UNKNOWN" union "UNION" unique "UNIQUE" @@ -1796,6 +1797,23 @@ Literal: | floatLit | intLit | stringLit +| "UNDERSCORE_CHARSET" stringLit + { + // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html + tp := types.NewFieldType(mysql.TypeString) + tp.Charset = $1.(string) + co, err := charset.GetDefaultCollation(tp.Charset) + if err != nil { + l := yylex.(*lexer) + l.errf("Get collation error for charset: %s", tp.Charset) + return 1 + } + tp.Collate = co + $$ = &types.DataItem{ + Type: tp, + Data: $2.(string), + } + } | hexLit | bitLit diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index 048817d1b9..adaa662651 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/stringutil" ) @@ -1010,7 +1011,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} return integerType {ident} lval.item = string(l.val) - return identifier + return l.handleIdent(lval) . return c0 @@ -1101,3 +1102,19 @@ func (l *lexer) bit(lval *yySymType) int { lval.item = b return bitLit } + +func (l *lexer) handleIdent(lval *yySymType) int { + s := lval.item.(string) + // A character string literal may have an optional character set introducer and COLLATE clause: + // [_charset_name]'string' [COLLATE collation_name] + // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html + if !strings.HasPrefix(s, "_") { + return identifier + } + cs, _, err := charset.GetCharsetInfo(s[1:]) + if err != nil { + return identifier + } + lval.item = cs + return underscoreCS +} From f3016ff56c1371cf8e295b4246a9371c68af2004 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Sat, 31 Oct 2015 23:07:31 +0800 Subject: [PATCH 11/25] optimizer: fix parameter marker out of range --- optimizer/compiler.go | 6 +----- optimizer/convert_expr.go | 9 ++++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/optimizer/compiler.go b/optimizer/compiler.go index bff4623520..1fc875a206 100644 --- a/optimizer/compiler.go +++ b/optimizer/compiler.go @@ -43,11 +43,7 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { } c := newExpressionConverter() defer func() { - for _, v := range c.exprMap { - if x, ok := v.(*expression.ParamMarker); ok { - com.paramMarkers = append(com.paramMarkers, x) - } - } + com.paramMarkers = c.paramMarkers }() switch v := node.(type) { case *ast.InsertStmt: diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index ac57e91e77..52dfbe8859 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -33,8 +33,9 @@ func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression. // expressionConverter converts ast expression to // old expression for transition state. type expressionConverter struct { - exprMap map[ast.Node]expression.Expression - err error + exprMap map[ast.Node]expression.Expression + paramMarkers []*expression.ParamMarker + err error } func newExpressionConverter() *expressionConverter { @@ -265,7 +266,9 @@ func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { } func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { - c.exprMap[v] = &expression.ParamMarker{} + marker := &expression.ParamMarker{} + c.exprMap[v] = marker + c.paramMarkers = append(c.paramMarkers, marker) } func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { From eed9647417dd0148c7b648faa1675bda8cf06939 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Mon, 2 Nov 2015 11:34:13 +0800 Subject: [PATCH 12/25] ast, optimizer: add Offset for ParmMarker to make sure the order of param markers. --- ast/expressions.go | 1 + ast/parser/parser.y | 4 +++- optimizer/compiler.go | 30 +++++++++++++++++++++++++----- optimizer/convert_expr.go | 7 +++---- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/ast/expressions.go b/ast/expressions.go index 7c82ed2d2a..fe0d311b64 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -609,6 +609,7 @@ func (nod *PatternLikeExpr) IsStatic() bool { // Used in parsing prepare statement. type ParamMarkerExpr struct { exprNode + Offset int } // Accept implements Node Accept interface. diff --git a/ast/parser/parser.y b/ast/parser/parser.y index a32c488fdf..33a4f9f1be 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1844,7 +1844,9 @@ Operand: } | "PLACEHOLDER" { - $$ = &ast.ParamMarkerExpr{} + $$ = &ast.ParamMarkerExpr{ + Offset: yyS[yypt].offset, + } } | "ROW" '(' Expression ',' ExpressionList ')' { diff --git a/optimizer/compiler.go b/optimizer/compiler.go index 1fc875a206..a9e42cd8b0 100644 --- a/optimizer/compiler.go +++ b/optimizer/compiler.go @@ -14,6 +14,8 @@ package optimizer import ( + "sort" + "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression" @@ -22,7 +24,7 @@ import ( // Compiler compiles ast.Node into an executable statement. type Compiler struct { - paramMarkers []*expression.ParamMarker + converter *expressionConverter } // Compile compiles a ast.Node into an executable statement. @@ -42,9 +44,7 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { return nil, errors.Trace(tpComputer.err) } c := newExpressionConverter() - defer func() { - com.paramMarkers = c.paramMarkers - }() + com.converter = c switch v := node.(type) { case *ast.InsertStmt: return convertInsert(c, v) @@ -106,7 +106,27 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { return nil, nil } +type paramMarkers []*ast.ParamMarkerExpr + +func (p paramMarkers) Len() int { + return len(p) +} + +func (p paramMarkers) Less(i, j int) bool { + return p[i].Offset < p[j].Offset +} + +func (p paramMarkers) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + // ParamMarkers returns parameter markers for prepared statement. func (com *Compiler) ParamMarkers() []*expression.ParamMarker { - return com.paramMarkers + c := com.converter + sort.Sort(c.paramMarkers) + oldMarkers := make([]*expression.ParamMarker, len(c.paramMarkers)) + for i, val := range c.paramMarkers { + oldMarkers[i] = c.exprMap[val].(*expression.ParamMarker) + } + return oldMarkers } diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 52dfbe8859..e8aa1bce84 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -34,7 +34,7 @@ func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression. // old expression for transition state. type expressionConverter struct { exprMap map[ast.Node]expression.Expression - paramMarkers []*expression.ParamMarker + paramMarkers paramMarkers err error } @@ -266,9 +266,8 @@ func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { } func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { - marker := &expression.ParamMarker{} - c.exprMap[v] = marker - c.paramMarkers = append(c.paramMarkers, marker) + c.paramMarkers = append(c.paramMarkers, v) + c.exprMap[v] = &expression.ParamMarker{} } func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { From ab2ec90d93a3468291c21d057fd9076414764eb3 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Mon, 2 Nov 2015 13:34:04 +0800 Subject: [PATCH 13/25] ast/parser, optimizer: fix duplicate parameter marker, fix offset. --- ast/parser/parser.y | 2 +- optimizer/convert_expr.go | 6 ++++-- optimizer/convert_stmt.go | 3 --- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 33a4f9f1be..90de8fad3c 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -2833,7 +2833,7 @@ SelectStmtLimit: } | "LIMIT" LengthNum "OFFSET" LengthNum { - $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} + $$ = &ast.Limit{Offset: $4.(uint64), Count: $2.(uint64)} } SelectStmtDistinct: diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index e8aa1bce84..d43f93f016 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -266,8 +266,10 @@ func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { } func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { - c.paramMarkers = append(c.paramMarkers, v) - c.exprMap[v] = &expression.ParamMarker{} + if c.exprMap[v] == nil { + c.exprMap[v] = &expression.ParamMarker{} + c.paramMarkers = append(c.paramMarkers, v) + } } func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index ff026f6591..535dfcddea 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -762,7 +762,6 @@ func convertPrepare(converter *expressionConverter, v *ast.PrepareStmt) (*stmts. Text: v.Text(), } if v.SQLVar != nil { - converter := newExpressionConverter() oldSQLVar, err := convertExpr(converter, v.SQLVar) if err != nil { return nil, errors.Trace(err) @@ -815,7 +814,6 @@ func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowSt oldShow.ColumnName = joinColumnName(v.Column) } if v.Where != nil { - converter := newExpressionConverter() oldWhere, err := convertExpr(converter, v.Where) if err != nil { return nil, errors.Trace(err) @@ -823,7 +821,6 @@ func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowSt oldShow.Where = oldWhere } if v.Pattern != nil { - converter := newExpressionConverter() oldPattern, err := convertExpr(converter, v.Pattern) if err != nil { return nil, errors.Trace(err) From 067df3ab3438ef0aaa42097743b808884751ebd5 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Mon, 2 Nov 2015 18:37:49 +0800 Subject: [PATCH 14/25] optimizer: fix wildcard conversion. --- ast/expressions.go | 2 +- ast/functions.go | 3 +++ ast/parser/parser.y | 5 +++-- optimizer/convert_stmt.go | 9 ++++++++- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/ast/expressions.go b/ast/expressions.go index fe0d311b64..cddb7f45d5 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -66,7 +66,7 @@ func NewValueExpr(value interface{}) *ValueExpr { case uint64: ve.Type = types.NewFieldType(mysql.TypeLonglong) ve.Type.Flag |= mysql.UnsignedFlag - case string: + case string, UnquoteString: ve.Type = types.NewFieldType(mysql.TypeVarchar) ve.Type.Charset = mysql.DefaultCharset ve.Type.Collate = mysql.DefaultCollationName diff --git a/ast/functions.go b/ast/functions.go index 56bf663eeb..e3323ef3d1 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -31,6 +31,9 @@ var ( _ FuncNode = &AggregateFuncExpr{} ) +// UnquoteString is not quoted when printed. +type UnquoteString string + // FuncCallExpr is for function expression. type FuncCallExpr struct { funcNode diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 90de8fad3c..a35ea0d086 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -2036,7 +2036,7 @@ FunctionCallKeyword: | "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues { // TODO: support qualified identifier for column_name - $$ = &ast.ColumnNameExpr{Name: $3.(*ast.ColumnName)} + $$ = &ast.ValuesExpr{Column: $3.(*ast.ColumnName)} } | "WEEK" '(' ExpressionList ')' { @@ -2292,7 +2292,7 @@ FunctionCallAgg: } | "COUNT" '(' DistinctOpt '*' ')' { - args := []ast.ExprNode{ast.NewValueExpr("*")} + args := []ast.ExprNode{ast.NewValueExpr(ast.UnquoteString("*"))} $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' @@ -2753,6 +2753,7 @@ TableFactor: TableAsNameOpt: { $$ = model.CIStr{} + yyS[yypt].offset = yylex.(*lexer).i } | TableAsName { diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 535dfcddea..528e360754 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -200,7 +200,14 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se return nil, errors.Trace(err) } } else if val.WildCard != nil { - oldField.Expr = &expression.Ident{CIStr: model.NewCIStr("*")} + str := "*" + if val.WildCard.Table.O != "" { + str = val.WildCard.Table.O + ".*" + if val.WildCard.Schema.O != "" { + str = val.WildCard.Schema.O + "." + str + } + } + oldField.Expr = &expression.Ident{CIStr: model.NewCIStr(str)} } oldSelect.Fields[i] = oldField } From ba9dfc3551e2b4ddf6afb6cbc0c8b9421082bb2e Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 10:00:17 +0800 Subject: [PATCH 15/25] ast/parser: add text to select field. --- ast/ast.go | 2 ++ ast/dml.go | 2 ++ ast/parser/parser.y | 64 +++++++++++++++++++++++++++++++++++---- ast/parser/scanner.l | 10 ++++++ optimizer/convert_stmt.go | 3 ++ 5 files changed, 75 insertions(+), 6 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index fc505c8362..6b699d796e 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -97,10 +97,12 @@ type ResultSetNode interface { // Visitor visits a Node. type Visitor interface { // VisitEnter is called before children nodes is visited. + // The returned node should be the same type as the input node n. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. + // The returned node should be the same type as the input node n. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) } diff --git a/ast/dml.go b/ast/dml.go index afd836ef66..b45fc6374d 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -203,6 +203,8 @@ func (nod *WildCardField) Accept(v Visitor) (Node, bool) { type SelectField struct { node + // Offset is used to get original text. + Offset int // If WildCard is not nil, Expr will be nil. WildCard *WildCardField // If Expr is not nil, WildCard will be nil. diff --git a/ast/parser/parser.y b/ast/parser/parser.y index a35ea0d086..1d6b05e314 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1516,7 +1516,15 @@ Field: } | Expression FieldAsNameOpt { - $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: model.NewCIStr($2.(string))} + expr := $1.(ast.ExprNode) + asName := $2.(string) + if asName != "" { + // Set expr original text. + offset := yyS[yypt-1].offset + end := yyS[yypt].offset-1 + expr.SetText(yylex.(*lexer).src[offset:end]) + } + $$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)} } FieldAsNameOpt: @@ -1550,11 +1558,22 @@ FieldAsName: FieldList: Field { - $$ = []*ast.SelectField{$1.(*ast.SelectField)} + field := $1.(*ast.SelectField) + field.Offset = yyS[yypt].offset + $$ = []*ast.SelectField{field} } | FieldList ',' Field { - $$ = append($1.([]*ast.SelectField), $3.(*ast.SelectField)) + + fl := $1.([]*ast.SelectField) + last := fl[len(fl)-1] + if last.Expr != nil && last.AsName.O == "" { + lastEnd := yyS[yypt-1].offset-1 // Comma offset. + last.SetText(yylex.(*lexer).src[last.Offset:lastEnd]) + } + newField := $3.(*ast.SelectField) + newField.Offset = yyS[yypt].offset + $$ = append(fl, newField) } GroupByClause: @@ -2628,6 +2647,22 @@ SelectStmt: Fields: $3.(*ast.FieldList), LockTp: $5.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + src := yylex.(*lexer).src + var lastEnd int + if $4 != nil { + lastEnd = yyS[yypt-1].offset-1 + } else if $5 != ast.SelectLockNone { + lastEnd = yyS[yypt].offset-1 + } else { + lastEnd = len(src) + if src[lastEnd-1] == ';' { + lastEnd-- + } + } + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } if $4 != nil { st.Limit = $4.(*ast.Limit) } @@ -2640,6 +2675,11 @@ SelectStmt: Fields: $3.(*ast.FieldList), LockTp: $7.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-3].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } if $5 != nil { st.Where = $5.(ast.ExprNode) } @@ -2659,6 +2699,12 @@ SelectStmt: LockTp: $11.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-7].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } + if $6 != nil { st.Where = $6.(ast.ExprNode) } @@ -2739,6 +2785,8 @@ TableFactor: } | '(' SelectStmt ')' TableAsName { + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1) $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } | '(' UnionStmt ')' TableAsName @@ -2753,7 +2801,6 @@ TableFactor: TableAsNameOpt: { $$ = model.CIStr{} - yyS[yypt].offset = yylex.(*lexer).i } | TableAsName { @@ -2886,6 +2933,7 @@ SubSelect: '(' SelectStmt ')' { s := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) @@ -2928,7 +2976,9 @@ UnionStmt: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) - union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) + st := $5.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, st) if $7 != nil { union.OrderBy = $7.(*ast.OrderByClause) } @@ -2958,7 +3008,9 @@ UnionSelect: SelectStmt | '(' SelectStmt ')' { - $$ = $2 + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1) + $$ = st } UnionOpt: diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index adaa662651..b58a8f6f3e 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -118,6 +118,16 @@ func (l *lexer) GetCharsetInfo() (string, string) { return l.charset, l.collation } +// The select statement is not at the end of the whole statement, if the last +// field text was set from its offset to the end of the src string, update +// the last field text. +func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 { + lastField.SetText(l.src[lastField.Offset:lastEnd]) + } +} + func (l *lexer) unget(b byte) { l.ungetBuf = append(l.ungetBuf, b) l.i-- diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 528e360754..321cdada18 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -199,6 +199,9 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se if err != nil { return nil, errors.Trace(err) } + if oldField.AsName == "" { + oldField.AsName = val.Text() + } } else if val.WildCard != nil { str := "*" if val.WildCard.Table.O != "" { From 53a11459c2d408f3d9ea3a7785f8f9ed87eae4a8 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 10:28:33 +0800 Subject: [PATCH 16/25] optimizer: only set old field AsName when it is not identifier --- optimizer/convert_stmt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 321cdada18..875eb62aea 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -199,7 +199,7 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se if err != nil { return nil, errors.Trace(err) } - if oldField.AsName == "" { + if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" { oldField.AsName = val.Text() } } else if val.WildCard != nil { From 5ca94b542e55b182c4c836f5c38671eb24e71dc4 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 11:18:34 +0800 Subject: [PATCH 17/25] optimizer: add comments for set field text. --- optimizer/convert_stmt.go | 1 + 1 file changed, 1 insertion(+) diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 875eb62aea..98459ff694 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -199,6 +199,7 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se if err != nil { return nil, errors.Trace(err) } + // TODO: handle parenthesesed column name expression, which should not set AsName. if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" { oldField.AsName = val.Text() } From a74b373a5271df29bd3da9ab8c24eb5e36ea8548 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Tue, 3 Nov 2015 11:24:47 +0800 Subject: [PATCH 18/25] parser: Add charset info for string literal --- ast/parser/parser.y | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 127a3b965d..da9e510966 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1797,6 +1797,15 @@ Literal: | floatLit | intLit | stringLit + { + tp := types.NewFieldType(mysql.TypeString) + l := yylex.(*lexer) + tp.Charset, tp.Collate = l.GetCharsetInfo() + $$ = &types.DataItem{ + Type: tp, + Data: $1.(string), + } + } | "UNDERSCORE_CHARSET" stringLit { // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html From 6b8a39523f85b96fbf324fd9efd8a952cdd3d845 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 13:53:46 +0800 Subject: [PATCH 19/25] ast/parser: add tests and fix bugs. --- ast/dml.go | 39 +-- ast/functions.go | 47 ++++ ast/parser/parser.y | 79 +++++- ast/parser/parser_test.go | 568 ++++++++++++++++++++++++++++++++++++++ ast/parser/scanner.l | 8 + optimizer/convert_expr.go | 17 ++ optimizer/convert_stmt.go | 26 +- 7 files changed, 734 insertions(+), 50 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index b45fc6374d..f8cbd01c11 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -642,7 +642,7 @@ type DeleteStmt struct { // Only used in multiple table delete statement. Tables []*TableName Where ExprNode - Order []*ByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -680,20 +680,20 @@ func (nod *DeleteStmt) Accept(v Visitor) (Node, bool) { } nod.Where = node.(ExprNode) } - - for i, val := range nod.Order { - node, ok = val.Accept(v) + if nod.Order != nil { + node, ok = nod.Order.Accept(v) if !ok { return nod, false } - nod.Order[i] = node.(*ByItem) + nod.Order = node.(*OrderByClause) } - - node, ok = nod.Limit.Accept(v) - if !ok { - return nod, false + if nod.Limit != nil { + node, ok = nod.Limit.Accept(v) + if !ok { + return nod, false + } + nod.Limit = node.(*Limit) } - nod.Limit = node.(*Limit) return v.Leave(nod) } @@ -705,7 +705,7 @@ type UpdateStmt struct { TableRefs *TableRefsClause List []*Assignment Where ExprNode - Order []*ByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -738,19 +738,20 @@ func (nod *UpdateStmt) Accept(v Visitor) (Node, bool) { } nod.Where = node.(ExprNode) } - - for i, val := range nod.Order { - node, ok = val.Accept(v) + if nod.Order != nil { + node, ok = nod.Order.Accept(v) if !ok { return nod, false } - nod.Order[i] = node.(*ByItem) + nod.Order = node.(*OrderByClause) } - node, ok = nod.Limit.Accept(v) - if !ok { - return nod, false + if nod.Limit != nil { + node, ok = nod.Limit.Accept(v) + if !ok { + return nod, false + } + nod.Limit = node.(*Limit) } - nod.Limit = node.(*Limit) return v.Leave(nod) } diff --git a/ast/functions.go b/ast/functions.go index e3323ef3d1..8d688f9eee 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -28,6 +28,7 @@ var ( _ FuncNode = &FuncSubstringExpr{} _ FuncNode = &FuncLocateExpr{} _ FuncNode = &FuncTrimExpr{} + _ FuncNode = &FuncDateArithExpr{} _ FuncNode = &AggregateFuncExpr{} ) @@ -337,6 +338,52 @@ func (nod *FuncTrimExpr) IsStatic() bool { return nod.Str.IsStatic() && nod.RemStr.IsStatic() } +// DateArithType is type for DateArith option. +type DateArithType byte + +const ( + // DateAdd is to run date_add function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add + DateAdd DateArithType = iota + 1 + // DateSub is to run date_sub function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub + DateSub +) + +// FuncDateArithExpr is the struct for date arithmetic functions. +type FuncDateArithExpr struct { + funcNode + + Op DateArithType + Unit string + Date ExprNode + Interval ExprNode +} + +// Accept implements Node Accept interface. +func (nod *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*FuncDateArithExpr) + if nod.Date != nil { + node, ok := nod.Date.Accept(v) + if !ok { + return nod, false + } + nod.Date = node.(ExprNode) + } + if nod.Interval != nil { + node, ok := nod.Date.Accept(v) + if !ok { + return nod, false + } + nod.Date = node.(ExprNode) + } + return v.Leave(nod) +} + // AggregateFuncExpr represents aggregate function expression. type AggregateFuncExpr struct { funcNode diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 1d6b05e314..4175b2c1f8 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -104,6 +104,8 @@ import ( currentUser "CURRENT_USER" database "DATABASE" databases "DATABASES" + dateAdd "DATE_ADD" + dateSub "DATE_SUB" day "DAY" dayofmonth "DAYOFMONTH" dayofweek "DAYOFWEEK" @@ -156,6 +158,7 @@ import ( index "INDEX" inner "INNER" insert "INSERT" + interval "INTERVAL" into "INTO" is "IS" join "JOIN" @@ -839,7 +842,7 @@ ColumnOption: { // See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html // The CHECK clause is parsed but ignored by all storage engines. - $$ = nil + $$ = &ast.ColumnOption{} } ColumnOptionList: @@ -1159,7 +1162,7 @@ DeleteFromStmt: x.Where = $7.(ast.ExprNode) } if $8 != nil { - x.Order = $8.([]*ast.ByItem) + x.Order = $8.(*ast.OrderByClause) } if $9 != nil { x.Limit = $9.(*ast.Limit) @@ -1650,7 +1653,7 @@ UnReservedKeyword: | "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" + "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" | "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" | "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" @@ -1689,12 +1692,17 @@ InsertValues: { $$ = &ast.InsertStmt{ Columns: $2.([]*ast.ColumnName), - Lists: $5.([][]ast.ExprNode)} + Lists: $5.([][]ast.ExprNode), + } } | '(' ColumnNameListOpt ')' SelectStmt { $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.SelectStmt)} } +| '(' ColumnNameListOpt ')' UnionStmt + { + $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.UnionStmt)} + } | ValueSym ExpressionListList %prec insertValues { $$ = &ast.InsertStmt{Lists: $2.([][]ast.ExprNode)} @@ -2111,6 +2119,24 @@ FunctionCallNonKeyword: { $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } +| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateAdd, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), + } + } +| "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateSub, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), + } + } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { $$ = &ast.FuncExtractExpr{ @@ -2969,6 +2995,8 @@ UnionStmt: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) $$ = union } @@ -2976,6 +3004,8 @@ UnionStmt: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1) st := $5.(*ast.SelectStmt) yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) union.Selects = append(union.Selects, st) @@ -3000,6 +3030,8 @@ UnionClauseList: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) $$ = union } @@ -3216,7 +3248,7 @@ ShowStmt: } if x, ok := $4.(*ast.PatternLikeExpr); ok { stmt.Pattern = x - } else { + } else if $4 != nil { stmt.Where = $4.(ast.ExprNode) } $$ = stmt @@ -3228,7 +3260,7 @@ ShowStmt: } if x, ok := $3.(*ast.PatternLikeExpr); ok { stmt.Pattern = x - } else { + } else if $3 != nil { stmt.Where = $3.(ast.ExprNode) } $$ = stmt @@ -3970,7 +4002,34 @@ StringName: * See: https://dev.mysql.com/doc/refman/5.7/en/update.html ***********************************************************************************/ UpdateStmt: - "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause + "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause + { + var refs *ast.Join + if x, ok := $4.(*ast.Join); ok { + refs = x + } else { + refs = &ast.Join{Left: $4.(ast.ResultSetNode)} + } + st := &ast.UpdateStmt{ + LowPriority: $2.(bool), + TableRefs: &ast.TableRefsClause{TableRefs: refs}, + List: $6.([]*ast.Assignment), + } + if $7 != nil { + st.Where = $7.(ast.ExprNode) + } + if $8 != nil { + st.Order = $8.(*ast.OrderByClause) + } + if $9 != nil { + st.Limit = $9.(*ast.Limit) + } + $$ = st + if yylex.(*lexer).root { + break + } + } +| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional { st := &ast.UpdateStmt{ LowPriority: $2.(bool), @@ -3980,12 +4039,6 @@ UpdateStmt: if $7 != nil { st.Where = $7.(ast.ExprNode) } - if $8 != nil { - st.Order = $8.([]*ast.ByItem) - } - if $9 != nil { - st.Limit = $9.(*ast.Limit) - } $$ = st if yylex.(*lexer).root { break diff --git a/ast/parser/parser_test.go b/ast/parser/parser_test.go index 1e2798c28c..3e8d7bb278 100644 --- a/ast/parser/parser_test.go +++ b/ast/parser/parser_test.go @@ -77,3 +77,571 @@ func (s *testParserSuite) TestSimple(c *C) { c.Assert(ok, IsTrue) } } + +type testCase struct { + src string + ok bool +} + +func (s *testParserSuite) RunTest(c *C, table []testCase) { + for _, t := range table { + l := NewLexer(t.src) + ok := yyParse(l) == 0 + c.Assert(ok, Equals, t.ok, Commentf("source %v %v", t.src, l.errs)) + switch ok { + case true: + c.Assert(l.errs, HasLen, 0, Commentf("src: %s", t.src)) + case false: + c.Assert(len(l.errs), Not(Equals), 0, Commentf("src: %s", t.src)) + } + } +} +func (s *testParserSuite) TestDMLStmt(c *C) { + table := []testCase{ + {"", true}, + {";", true}, + {"INSERT INTO foo VALUES (1234)", true}, + {"INSERT INTO foo VALUES (1234, 5678)", true}, + // 15 + {"INSERT INTO foo VALUES (1 || 2)", true}, + {"INSERT INTO foo VALUES (1 | 2)", true}, + {"INSERT INTO foo VALUES (false || true)", true}, + {"INSERT INTO foo VALUES (bar(5678))", false}, + // 20 + {"INSERT INTO foo VALUES ()", true}, + {"SELECT * FROM t", true}, + {"SELECT * FROM t AS u", true}, + // 25 + {"SELECT * FROM t, v", true}, + {"SELECT * FROM t AS u, v", true}, + {"SELECT * FROM t, v AS w", true}, + {"SELECT * FROM t AS u, v AS w", true}, + {"SELECT * FROM foo, bar, foo", true}, + // 30 + {"SELECT DISTINCTS * FROM t", false}, + {"SELECT DISTINCT * FROM t", true}, + {"INSERT INTO foo (a) VALUES (42)", true}, + {"INSERT INTO foo (a,) VALUES (42,)", false}, + // 35 + {"INSERT INTO foo (a,b) VALUES (42,314)", true}, + {"INSERT INTO foo (a,b,) VALUES (42,314)", false}, + {"INSERT INTO foo (a,b,) VALUES (42,314,)", false}, + {"INSERT INTO foo () VALUES ()", true}, + {"INSERT INTO foo VALUE ()", true}, + + {"REPLACE INTO foo VALUES (1 || 2)", true}, + {"REPLACE INTO foo VALUES (1 | 2)", true}, + {"REPLACE INTO foo VALUES (false || true)", true}, + {"REPLACE INTO foo VALUES (bar(5678))", false}, + {"REPLACE INTO foo VALUES ()", true}, + {"REPLACE INTO foo (a,b) VALUES (42,314)", true}, + {"REPLACE INTO foo (a,b,) VALUES (42,314)", false}, + {"REPLACE INTO foo (a,b,) VALUES (42,314,)", false}, + {"REPLACE INTO foo () VALUES ()", true}, + {"REPLACE INTO foo VALUE ()", true}, + // 40 + {`SELECT stuff.id + FROM stuff + WHERE stuff.value >= ALL (SELECT stuff.value + FROM stuff)`, true}, + {"BEGIN", true}, + {"START TRANSACTION", true}, + // 45 + {"COMMIT", true}, + {"ROLLBACK", true}, + {` + BEGIN; + INSERT INTO foo VALUES (42, 3.14); + INSERT INTO foo VALUES (-1, 2.78); + COMMIT;`, true}, + {` // A + BEGIN; + INSERT INTO tmp SELECT * from bar; + SELECT * from tmp; + + // B + ROLLBACK;`, true}, + + // set + // user defined + {"SET @a = 1", true}, + // session system variables + {"SET SESSION autocommit = 1", true}, + {"SET @@session.autocommit = 1", true}, + {"SET LOCAL autocommit = 1", true}, + {"SET @@local.autocommit = 1", true}, + {"SET @@autocommit = 1", true}, + {"SET autocommit = 1", true}, + // global system variables + {"SET GLOBAL autocommit = 1", true}, + {"SET @@global.autocommit = 1", true}, + // SET CHARACTER SET + {"SET CHARACTER SET utf8mb4;", true}, + {"SET CHARACTER SET 'utf8mb4';", true}, + // Set password + {"SET PASSWORD = 'password';", true}, + {"SET PASSWORD FOR 'root'@'localhost' = 'password';", true}, + + // qualified select + {"SELECT a.b.c FROM t", true}, + {"SELECT a.b.*.c FROM t", false}, + {"SELECT a.b.* FROM t", true}, + {"SELECT a FROM t", true}, + {"SELECT a.b.c.d FROM t", false}, + + // Do statement + {"DO 1", true}, + {"DO 1 from t", false}, + + // Select for update + {"SELECT * from t for update", true}, + {"SELECT * from t lock in share mode", true}, + + // For alter table + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED", true}, + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED FIRST", true}, + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED AFTER b", true}, + + // from join + {"SELECT * from t1, t2, t3", true}, + {"select * from t1 join t2 left join t3 on t2.id = t3.id", true}, + {"select * from t1 right join t2 on t1.id = t2.id left join t3 on t3.id = t2.id", true}, + {"select * from t1 right join t2 on t1.id = t2.id left join t3", false}, + + // For show full columns + {"show columns in t;", true}, + {"show full columns in t;", true}, + + // For set names + {"set names utf8", true}, + {"set names utf8 collate utf8_unicode_ci", true}, + + // For show character set + {"show character set;", true}, + // For on duplicate key update + {"INSERT INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, + {"INSERT IGNORE INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, + + // For SHOW statement + {"SHOW VARIABLES LIKE 'character_set_results'", true}, + {"SHOW GLOBAL VARIABLES LIKE 'character_set_results'", true}, + {"SHOW SESSION VARIABLES LIKE 'character_set_results'", true}, + {"SHOW VARIABLES", true}, + {"SHOW GLOBAL VARIABLES", true}, + {"SHOW GLOBAL VARIABLES WHERE Variable_name = 'autocommit'", true}, + {`SHOW FULL TABLES FROM icar_qa LIKE play_evolutions`, true}, + {`SHOW FULL TABLES WHERE Table_Type != 'VIEW'`, true}, + {`SHOW GRANTS`, true}, + {`SHOW GRANTS FOR 'test'@'localhost'`, true}, + + // For default value + {"CREATE TABLE sbtest (id INTEGER UNSIGNED NOT NULL AUTO_INCREMENT, k integer UNSIGNED DEFAULT '0' NOT NULL, c char(120) DEFAULT '' NOT NULL, pad char(60) DEFAULT '' NOT NULL, PRIMARY KEY (id) )", true}, + + // For delete statement + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, + + // For update statement + {"UPDATE t SET id = id + 1 ORDER BY id DESC;", true}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id;", true}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id LIMIT 10;", false}, + {"UPDATE user T0 LEFT OUTER JOIN user_profile T1 ON T1.id = T0.profile_id SET T0.profile_id = 1 WHERE T0.profile_id IN (1);", true}, + + // For select with where clause + {"SELECT * FROM t WHERE 1 = 1", true}, + + // For show collation + {"show collation", true}, + {"show collation like 'utf8%'", true}, + {"show collation where Charset = 'utf8' and Collation = 'utf8_bin'", true}, + + // For dual + {"select 1 from dual", true}, + {"select 1 from dual limit 1", true}, + {"select 1 where exists (select 2)", false}, + {"select 1 from dual where not exists (select 2)", true}, + + // For show create table + {"show create table test.t", true}, + {"show create table t", true}, + + // For https://github.com/pingcap/tidb/issues/320 + {`(select 1);`, true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestExpression(c *C) { + table := []testCase{ + // Sign expression + {"SELECT ++1", true}, + {"SELECT -*1", false}, + {"SELECT -+1", true}, + {"SELECT -1", true}, + {"SELECT --1", true}, + + // For string literal + {`select '''a''', """a"""`, true}, + {`select ''a''`, false}, + {`select ""a""`, false}, + {`select '''a''';`, true}, + {`select '\'a\'';`, true}, + {`select "\"a\"";`, true}, + {`select """a""";`, true}, + {`select _utf8"string";`, true}, + // For comparison + {"select 1 <=> 0, 1 <=> null, 1 = null", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestBuiltin(c *C) { + table := []testCase{ + // For buildin functions + {"SELECT DAYOFMONTH('2007-02-03');", true}, + {"SELECT RAND();", true}, + {"SELECT RAND(1);", true}, + + {"SELECT SUBSTRING('Quadratically',5);", true}, + {"SELECT SUBSTRING('Quadratically',5, 3);", true}, + {"SELECT SUBSTRING('Quadratically' FROM 5);", true}, + {"SELECT SUBSTRING('Quadratically' FROM 5 FOR 3);", true}, + + {"SELECT CONVERT('111', SIGNED);", true}, + + {"SELECT DATABASE();", true}, + {"SELECT USER();", true}, + {"SELECT CURRENT_USER();", true}, + {"SELECT CURRENT_USER;", true}, + + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', 2);", true}, + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', -2);", true}, + + {`SELECT LOWER("A"), UPPER("a")`, true}, + + {`SELECT REPLACE('www.mysql.com', 'w', 'Ww')`, true}, + + {`SELECT LOCATE('bar', 'foobarbar');`, true}, + {`SELECT LOCATE('bar', 'foobarbar', 5);`, true}, + + {"select current_date, current_date(), curdate()", true}, + + // For delete statement + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, + + // For time fsp + {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, + + // For row + {"select row(1)", false}, + {"select row(1, 1,)", false}, + {"select (1, 1,)", false}, + {"select row(1, 1) > row(1, 1), row(1, 1, 1) > row(1, 1, 1)", true}, + {"Select (1, 1) > (1, 1)", true}, + {"create table t (row int)", true}, + + // For cast with charset + {"SELECT *, CAST(data AS CHAR CHARACTER SET utf8) FROM t;", true}, + + // For binary operator + {"SELECT binary 'a';", true}, + + // Select time + {"select current_timestamp", true}, + {"select current_timestamp()", true}, + {"select current_timestamp(6)", true}, + {"select now()", true}, + {"select now(6)", true}, + {"select sysdate(), sysdate(6)", true}, + + // For time extract + {`select extract(microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(week from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(month from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(quarter from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(year from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(second_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_hour from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(year_month from "2011-11-11 10:10:10.123456")`, true}, + + // For issue 224 + {`SELECT CAST('test collated returns' AS CHAR CHARACTER SET utf8) COLLATE utf8_bin;`, true}, + + // For trim + {`SELECT TRIM(' bar ');`, true}, + {`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true}, + {`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true}, + {`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true}, + + // For date_add + {`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + + // For date_sub + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestIdentifier(c *C) { + table := []testCase{ + // For quote identifier + {"select `a`, `a.b`, `a b` from t", true}, + // For unquoted identifier + {"create table MergeContextTest$Simple (value integer not null, primary key (value))", true}, + // For as + {"select 1 as a, 1 as `a`, 1 as \"a\", 1 as 'a'", true}, + {`select 1 as a, 1 as "a", 1 as 'a'`, true}, + {`select 1 a, 1 "a", 1 'a'`, true}, + {`select * from t as "a"`, false}, + {`select * from t a`, true}, + {`select * from t as a`, true}, + {"select 1 full, 1 row, 1 abs", true}, + {"select * from t full, t1 row, t2 abs", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestDDL(c *C) { + table := []testCase{ + {"CREATE", false}, + {"CREATE TABLE", false}, + {"CREATE TABLE foo (", false}, + {"CREATE TABLE foo ()", false}, + {"CREATE TABLE foo ();", false}, + {"CREATE TABLE foo (a TINYINT UNSIGNED);", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, + {"CREATE TABLE foo (a bigint unsigned, b bool);", true}, + {"CREATE TABLE foo (a TINYINT, b SMALLINT) CREATE TABLE bar (x INT, y int64)", false}, + {"CREATE TABLE foo (a int, b float); CREATE TABLE bar (x double, y float)", true}, + {"CREATE TABLE foo (a bytes)", false}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) -- foo", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) // foo", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, + {"CREATE TABLE foo /* foo */ (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, + {"CREATE TABLE foo (name CHAR(50) BINARY)", true}, + {"CREATE TABLE foo (name CHAR(50) COLLATE utf8_bin)", true}, + {"CREATE TABLE foo (name CHAR(50) CHARACTER SET utf8)", true}, + {"CREATE TABLE foo (name CHAR(50) BINARY CHARACTER SET utf8 COLLATE utf8_bin)", true}, + + {"CREATE TABLE foo (a.b, b);", false}, + {"CREATE TABLE foo (a, b.c);", false}, + // For table option + {"create table t (c int) avg_row_length = 3", true}, + {"create table t (c int) avg_row_length 3", true}, + {"create table t (c int) checksum = 0", true}, + {"create table t (c int) checksum 1", true}, + {"create table t (c int) compression = none", true}, + {"create table t (c int) compression lz4", true}, + {"create table t (c int) connection = 'abc'", true}, + {"create table t (c int) connection 'abc'", true}, + {"create table t (c int) key_block_size = 1024", true}, + {"create table t (c int) key_block_size 1024", true}, + {"create table t (c int) max_rows = 1000", true}, + {"create table t (c int) max_rows 1000", true}, + {"create table t (c int) min_rows = 1000", true}, + {"create table t (c int) min_rows 1000", true}, + {"create table t (c int) password = 'abc'", true}, + {"create table t (c int) password 'abc'", true}, + // For check clause + {"create table t (c1 bool, c2 bool, check (c1 in (0, 1)), check (c2 in (0, 1)))", true}, + {"CREATE TABLE Customer (SD integer CHECK (SD > 0), First_Name varchar(30));", true}, + + {"create database xxx", true}, + {"create database if exists xxx", false}, + {"create database if not exists xxx", true}, + {"create schema xxx", true}, + {"create schema if exists xxx", false}, + {"create schema if not exists xxx", true}, + // For drop datbase/schema/table + {"drop database xxx", true}, + {"drop database if exists xxx", true}, + {"drop database if not exists xxx", false}, + {"drop schema xxx", true}, + {"drop schema if exists xxx", true}, + {"drop schema if not exists xxx", false}, + {"drop table xxx", true}, + {"drop table xxx, yyy", true}, + {"drop tables xxx", true}, + {"drop tables xxx, yyy", true}, + {"drop table if exists xxx", true}, + {"drop table if not exists xxx", false}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestType(c *C) { + table := []testCase{ + // For time fsp + {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, + + // For hexadecimal + {"SELECT x'0a', X'11', 0x11", true}, + {"select x'0xaa'", false}, + {"select 0X11", false}, + + // For bit + {"select 0b01, 0b0, b'11', B'11'", true}, + {"select 0B01", false}, + {"select 0b21", false}, + + // For enum and set type + {"create table t (c1 enum('a', 'b'), c2 set('a', 'b'))", true}, + {"create table t (c1 enum)", false}, + {"create table t (c1 set)", false}, + + // For blob and text field length + {"create table t (c1 blob(1024), c2 text(1024))", true}, + + // For year + {"create table t (y year(4), y1 year)", true}, + + // For national + {"create table t (c1 national char(2), c2 national varchar(2))", true}, + + // For https://github.com/pingcap/tidb/issues/312 + {`create table t (c float(53));`, true}, + {`create table t (c float(54));`, false}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestPrivilege(c *C) { + table := []testCase{ + // For create user + {`CREATE USER IF NOT EXISTS 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password', 'root'@'127.0.0.1' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + + // For grant statement + {"GRANT ALL ON db1.* TO 'jeffrey'@'localhost';", true}, + {"GRANT SELECT ON db2.invoice TO 'jeffrey'@'localhost';", true}, + {"GRANT ALL ON *.* TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON *.* TO 'someuser'@'somehost';", true}, + {"GRANT ALL ON mydb.* TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON mydb.* TO 'someuser'@'somehost';", true}, + {"GRANT ALL ON mydb.mytbl TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON mydb.mytbl TO 'someuser'@'somehost';", true}, + {"GRANT SELECT (col1), INSERT (col1,col2) ON mydb.mytbl TO 'someuser'@'somehost';", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestComment(c *C) { + table := []testCase{ + {"create table t (c int comment 'comment')", true}, + {"create table t (c int) comment = 'comment'", true}, + {"create table t (c int) comment 'comment'", true}, + {"create table t (c int) comment comment", false}, + {"create table t (comment text)", true}, + // For comment in query + {"/*comment*/ /*comment*/ select c /* this is a comment */ from t;", true}, + } + s.RunTest(c, table) +} +func (s *testParserSuite) TestSubquery(c *C) { + table := []testCase{ + // For compare subquery + {"SELECT 1 > (select 1)", true}, + {"SELECT 1 > ANY (select 1)", true}, + {"SELECT 1 > ALL (select 1)", true}, + {"SELECT 1 > SOME (select 1)", true}, + + // For exists subquery + {"SELECT EXISTS select 1", false}, + {"SELECT EXISTS (select 1)", true}, + {"SELECT + EXISTS (select 1)", true}, + {"SELECT - EXISTS (select 1)", true}, + {"SELECT NOT EXISTS (select 1)", true}, + {"SELECT + NOT EXISTS (select 1)", false}, + {"SELECT - NOT EXISTS (select 1)", false}, + } + s.RunTest(c, table) +} +func (s *testParserSuite) TestUnion(c *C) { + table := []testCase{ + {"select c1 from t1 union select c2 from t2", true}, + {"select c1 from t1 union (select c2 from t2)", true}, + {"select c1 from t1 union (select c2 from t2) order by c1", true}, + {"select c1 from t1 union select c2 from t2 order by c2", true}, + {"select c1 from t1 union (select c2 from t2) limit 1", true}, + {"select c1 from t1 union (select c2 from t2) limit 1, 1", true}, + {"select c1 from t1 union (select c2 from t2) order by c1 limit 1", true}, + {"(select c1 from t1) union distinct select c2 from t2", true}, + {"(select c1 from t1) union all select c2 from t2", true}, + {"(select c1 from t1) union (select c2 from t2) order by c1 union select c3 from t3", false}, + {"(select c1 from t1) union (select c2 from t2) limit 1 union select c3 from t3", false}, + {"(select c1 from t1) union select c2 from t2 union (select c3 from t3) order by c1 limit 1", true}, + {"select (select 1 union select 1) as a", true}, + {"select * from (select 1 union select 2) as a", true}, + {"insert into t select c1 from t1 union select c2 from t2", true}, + {"insert into t (c) select c1 from t1 union select c2 from t2", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestLikeEscape(c *C) { + table := []testCase{ + // For like escape + {`select "abc_" like "abc\\_" escape ''`, true}, + {`select "abc_" like "abc\\_" escape '\\'`, true}, + {`select "abc_" like "abc\\_" escape '||'`, false}, + {`select "abc" like "escape" escape '+'`, true}, + } + + s.RunTest(c, table) +} diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index b58a8f6f3e..9ee7b1b8a9 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -307,6 +307,8 @@ current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e} current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r} database {d}{a}{t}{a}{b}{a}{s}{e} databases {d}{a}{t}{a}{b}{a}{s}{e}{s} +date_add {d}{a}{t}{e}_{a}{d}{d} +date_sub {d}{a}{t}{e}_{s}{u}{b} day {d}{a}{y} dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k} dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h} @@ -355,6 +357,7 @@ in {i}{n} index {i}{n}{d}{e}{x} inner {i}{n}{n}{e}{r} insert {i}{n}{s}{e}{r}{t} +interval {i}{n}{t}{e}{r}{v}{a}{l} into {i}{n}{t}{o} is {i}{s} join {j}{o}{i}{n} @@ -655,6 +658,10 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {database} lval.item = string(l.val) return database {databases} return databases +{date_add} lval.item = string(l.val) + return dateAdd +{date_sub} lval.item = string(l.val) + return dateSub {day} lval.item = string(l.val) return day {dayofweek} lval.item = string(l.val) @@ -738,6 +745,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {index} return index {inner} return inner {insert} return insert +{interval} return interval {into} return into {in} return in {is} return is diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index d43f93f016..a136819ae1 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -112,6 +112,8 @@ func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) { c.funcLocate(v) case *ast.FuncTrimExpr: c.funcTrim(v) + case *ast.FuncDateArithExpr: + c.funcDateArith(v) case *ast.AggregateFuncExpr: c.aggregateFunc(v) } @@ -396,6 +398,21 @@ func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) { c.exprMap[v] = oldTrim } +func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) { + oldDateArith := &expression.DateArith{ + Unit: v.Unit, + Date: c.exprMap[v.Date], + Interval: c.exprMap[v.Interval], + } + switch v.Op { + case ast.DateAdd: + oldDateArith.Op = expression.DateAdd + case ast.DateSub: + oldDateArith.Op = expression.DateSub + } + c.exprMap[v] = oldDateArith +} + func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) { oldAggregate := &expression.Call{ F: v.F, diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 98459ff694..df769237f1 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -123,16 +123,11 @@ func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.De } } if v.Order != nil { - orderRset := &rsets.OrderByRset{} - for _, val := range v.Order { - oldExpr, err := convertExpr(converter, val.Expr) - if err != nil { - return nil, errors.Trace(err) - } - orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} - orderRset.By = append(orderRset.By, orderItem) + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) } - oldDelete.Order = orderRset + oldDelete.Order = oldOrderBy } if v.Limit != nil { oldDelete.Limit = &rsets.LimitRset{Count: v.Limit.Count} @@ -167,16 +162,11 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up oldUpdate.List = append(oldUpdate.List, oldAssign) } if v.Order != nil { - orderRset := &rsets.OrderByRset{} - for _, val := range v.Order { - oldExpr, err := convertExpr(converter, val.Expr) - if err != nil { - return nil, errors.Trace(err) - } - orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} - orderRset.By = append(orderRset.By, orderItem) + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) } - oldUpdate.Order = orderRset + oldUpdate.Order = oldOrderBy } if v.Limit != nil { oldUpdate.Limit = &rsets.LimitRset{Count: v.Limit.Count} From fce519486f0d2ba45a8554de630faf2f6ab74d62 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 4 Nov 2015 17:00:38 +0800 Subject: [PATCH 20/25] ast: visit Select element first in InsertStmt. Address comment. --- ast/dml.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index f8cbd01c11..0b35a70d0e 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -586,6 +586,13 @@ func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { } nod = newNod.(*InsertStmt) + if nod.Select != nil { + node, ok := nod.Select.Accept(v) + if !ok { + return nod, false + } + nod.Select = node.(ResultSetNode) + } node, ok := nod.Table.Accept(v) if !ok { return nod, false @@ -622,13 +629,6 @@ func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { } nod.OnDuplicate[i] = node.(*Assignment) } - if nod.Select != nil { - node, ok := nod.Select.Accept(v) - if !ok { - return nod, false - } - nod.Select = node.(ResultSetNode) - } return v.Leave(nod) } From b8a7e686317b98a650a4e66fb5c8d0c3fdec0c55 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 4 Nov 2015 17:59:12 +0800 Subject: [PATCH 21/25] ast: add comments about implementation rule on Accept method. --- ast/ast.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ast/ast.go b/ast/ast.go index 6b699d796e..a10d99dbd7 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -26,6 +26,11 @@ type Node interface { // Accept accepts Visitor to visit itself. // The returned node should replace original node. // ok returns false to stop visiting. + // + // Implementation of this method should first call visitor.Enter, + // assign the returned node to its method receiver, if skipChildren returns true, + // children should be skipped. Otherwise, call its children in particular order that + // later elements depends on on former elements. Finally, return visitor.Leave. Accept(v Visitor) (node Node, ok bool) // Text returns the original text of the element. Text() string From 524e18dff1adb8a0cc4af1479f9bd94bfa6376d5 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 4 Nov 2015 19:51:33 +0800 Subject: [PATCH 22/25] ast: address comment. --- ast/ast.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index a10d99dbd7..a1021420f1 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -102,12 +102,12 @@ type ResultSetNode interface { // Visitor visits a Node. type Visitor interface { // VisitEnter is called before children nodes is visited. - // The returned node should be the same type as the input node n. + // The returned node must be the same type as the input node n. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. - // The returned node should be the same type as the input node n. + // The returned node must be the same type as the input node n. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) } From c623fafc727627eac306bc34f8dc44ed3323b441 Mon Sep 17 00:00:00 2001 From: ngaut Date: Thu, 5 Nov 2015 13:53:32 +0800 Subject: [PATCH 23/25] parser: clean up --- ast/parser/parser.y | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index f61843b28d..145d5ae0e7 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -2696,7 +2696,7 @@ SelectStmt: lastEnd-- } } - lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + lastField.SetText(src[lastField.Offset:lastEnd]) } if $4 != nil { st.Limit = $4.(*ast.Limit) From 9be82dbe27e10f528aa913c1c34dc0df3d55ebda Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 5 Nov 2015 12:44:18 +0800 Subject: [PATCH 24/25] ast: address comment. --- ast/ast.go | 2 +- ast/cloner.go | 8 ++++---- ast/parser/parser.y | 2 +- optimizer/aggregator.go | 4 ++-- optimizer/typecomputer.go | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index a1021420f1..c38c8bc32d 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -30,7 +30,7 @@ type Node interface { // Implementation of this method should first call visitor.Enter, // assign the returned node to its method receiver, if skipChildren returns true, // children should be skipped. Otherwise, call its children in particular order that - // later elements depends on on former elements. Finally, return visitor.Leave. + // later elements depends on former elements. Finally, return visitor.Leave. Accept(v Visitor) (node Node, ok bool) // Text returns the original text of the element. Text() string diff --git a/ast/cloner.go b/ast/cloner.go index be5d7e441e..d3d8d282e5 100644 --- a/ast/cloner.go +++ b/ast/cloner.go @@ -21,7 +21,7 @@ type Cloner struct { // Enter implements Visitor Enter interface. func (c *Cloner) Enter(node Node) (Node, bool) { - return cloneStruct(node), false + return copyStruct(node), false } // Leave implements Visitor Leave interface. @@ -29,9 +29,9 @@ func (c *Cloner) Leave(in Node) (out Node, ok bool) { return in, true } -// cloneStruct clone a node's struct value, if the struct has slice -// the cloned value should make a new slice and copy old slice to new slice. -func cloneStruct(in Node) (out Node) { +// copyStruct copies a node's struct value, if the struct has slice member, +// make a new slice and copy old slice value to new slice. +func copyStruct(in Node) (out Node) { switch v := in.(type) { case *ValueExpr: nv := *v diff --git a/ast/parser/parser.y b/ast/parser/parser.y index f61843b28d..6997d1cfbf 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -449,7 +449,7 @@ import ( OrderBy "ORDER BY clause" ByItem "BY item" OrderByOptional "Optional ORDER BY clause optional" - ByList "BY list" + ByList "BY list" OuterOpt "optional OUTER clause" QuickOptional "QUICK or empty" PasswordOpt "Password option" diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go index 29b2fa3ea3..452d047c47 100644 --- a/optimizer/aggregator.go +++ b/optimizer/aggregator.go @@ -16,10 +16,10 @@ package optimizer // Aggregator is the interface to // compute aggregate function result. type Aggregator interface { - // Input add a input value to aggregator. + // Input adds an input value to aggregator. // The input values are accumulated in the aggregator. Input(in ...interface{}) error - // Output use input values to compute the aggregated result. + // Output uses input values to compute the aggregated result. Output() interface{} // Clear clears the input values. Clear() diff --git a/optimizer/typecomputer.go b/optimizer/typecomputer.go index d1862d1380..f2cd35aff0 100644 --- a/optimizer/typecomputer.go +++ b/optimizer/typecomputer.go @@ -16,7 +16,7 @@ package optimizer import "github.com/pingcap/tidb/ast" // typeComputer is an ast Visitor that -// Compute types for ast.ExprNode. +// computes result type for ast.ExprNode. type typeComputer struct { err error } From 2f2e31c8f47c4b92d7ddd546a6534b74c6118e6d Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 5 Nov 2015 14:42:15 +0800 Subject: [PATCH 25/25] ast: rename ast node receiver 'nod' to 'n'. --- ast/ddl.go | 272 +++++++++++++-------------- ast/dml.go | 452 ++++++++++++++++++++++----------------------- ast/expressions.go | 440 +++++++++++++++++++++---------------------- ast/functions.go | 220 +++++++++++----------- ast/misc.go | 216 +++++++++++----------- 5 files changed, 800 insertions(+), 800 deletions(-) diff --git a/ast/ddl.go b/ast/ddl.go index 87a5ffba70..4f655d6f01 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -70,13 +70,13 @@ type CreateDatabaseStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateDatabaseStmt) - return v.Leave(nod) + n = newNod.(*CreateDatabaseStmt) + return v.Leave(n) } // DropDatabaseStmt is a statement to drop a database and all tables in the database. @@ -89,13 +89,13 @@ type DropDatabaseStmt struct { } // Accept implements Node Accept interface. -func (nod *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DropDatabaseStmt) - return v.Leave(nod) + n = newNod.(*DropDatabaseStmt) + return v.Leave(n) } // IndexColName is used for parsing index column name from SQL. @@ -107,18 +107,18 @@ type IndexColName struct { } // Accept implements Node Accept interface. -func (nod *IndexColName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IndexColName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IndexColName) - node, ok := nod.Column.Accept(v) + n = newNod.(*IndexColName) + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) - return v.Leave(nod) + n.Column = node.(*ColumnName) + return v.Leave(n) } // ReferenceDef is used for parsing foreign key reference option from SQL. @@ -131,25 +131,25 @@ type ReferenceDef struct { } // Accept implements Node Accept interface. -func (nod *ReferenceDef) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ReferenceDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ReferenceDef) - node, ok := nod.Table.Accept(v) + n = newNod.(*ReferenceDef) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(nod) + return v.Leave(n) } // ColumnOptionType is the type for ColumnOption. @@ -183,20 +183,20 @@ type ColumnOption struct { } // Accept implements Node Accept interface. -func (nod *ColumnOption) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnOption) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnOption) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*ColumnOption) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // ConstraintType is the type for Constraint. @@ -230,27 +230,27 @@ type Constraint struct { } // Accept implements Node Accept interface. -func (nod *Constraint) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Constraint) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Constraint) - for i, val := range nod.Keys { + n = newNod.(*Constraint) + for i, val := range n.Keys { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Keys[i] = node.(*IndexColName) + n.Keys[i] = node.(*IndexColName) } - if nod.Refer != nil { - node, ok := nod.Refer.Accept(v) + if n.Refer != nil { + node, ok := n.Refer.Accept(v) if !ok { - return nod, false + return n, false } - nod.Refer = node.(*ReferenceDef) + n.Refer = node.(*ReferenceDef) } - return v.Leave(nod) + return v.Leave(n) } // ColumnDef is used for parsing column definition from SQL. @@ -263,25 +263,25 @@ type ColumnDef struct { } // Accept implements Node Accept interface. -func (nod *ColumnDef) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnDef) - node, ok := nod.Name.Accept(v) + n = newNod.(*ColumnDef) + node, ok := n.Name.Accept(v) if !ok { - return nod, false + return n, false } - nod.Name = node.(*ColumnName) - for i, val := range nod.Options { + n.Name = node.(*ColumnName) + for i, val := range n.Options { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Options[i] = node.(*ColumnOption) + n.Options[i] = node.(*ColumnOption) } - return v.Leave(nod) + return v.Leave(n) } // CreateTableStmt is a statement to create a table. @@ -297,32 +297,32 @@ type CreateTableStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateTableStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*CreateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.Cols { + n.Table = node.(*TableName) + for i, val := range n.Cols { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Cols[i] = node.(*ColumnDef) + n.Cols[i] = node.(*ColumnDef) } - for i, val := range nod.Constraints { + for i, val := range n.Constraints { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Constraints[i] = node.(*Constraint) + n.Constraints[i] = node.(*Constraint) } - return v.Leave(nod) + return v.Leave(n) } // DropTableStmt is a statement to drop one or more tables. @@ -335,20 +335,20 @@ type DropTableStmt struct { } // Accept implements Node Accept interface. -func (nod *DropTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DropTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DropTableStmt) - for i, val := range nod.Tables { + n = newNod.(*DropTableStmt) + for i, val := range n.Tables { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - return v.Leave(nod) + return v.Leave(n) } // CreateIndexStmt is a statement to create an index. @@ -363,25 +363,25 @@ type CreateIndexStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateIndexStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateIndexStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*CreateIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(nod) + return v.Leave(n) } // DropIndexStmt is a statement to drop the index. @@ -395,18 +395,18 @@ type DropIndexStmt struct { } // Accept implements Node Accept interface. -func (nod *DropIndexStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DropIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DropIndexStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*DropIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - return v.Leave(nod) + n.Table = node.(*TableName) + return v.Leave(n) } // TableOptionType is the type for TableOption @@ -457,20 +457,20 @@ type ColumnPosition struct { } // Accept implements Node Accept interface. -func (nod *ColumnPosition) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnPosition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnPosition) - if nod.RelativeColumn != nil { - node, ok := nod.RelativeColumn.Accept(v) + n = newNod.(*ColumnPosition) + if n.RelativeColumn != nil { + node, ok := n.RelativeColumn.Accept(v) if !ok { - return nod, false + return n, false } - nod.RelativeColumn = node.(*ColumnName) + n.RelativeColumn = node.(*ColumnName) } - return v.Leave(nod) + return v.Leave(n) } // AlterTableType is the type for AlterTableSpec. @@ -503,41 +503,41 @@ type AlterTableSpec struct { } // Accept implements Node Accept interface. -func (nod *AlterTableSpec) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *AlterTableSpec) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*AlterTableSpec) - if nod.Constraint != nil { - node, ok := nod.Constraint.Accept(v) + n = newNod.(*AlterTableSpec) + if n.Constraint != nil { + node, ok := n.Constraint.Accept(v) if !ok { - return nod, false + return n, false } - nod.Constraint = node.(*Constraint) + n.Constraint = node.(*Constraint) } - if nod.Column != nil { - node, ok := nod.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnDef) + n.Column = node.(*ColumnDef) } - if nod.DropColumn != nil { - node, ok := nod.DropColumn.Accept(v) + if n.DropColumn != nil { + node, ok := n.DropColumn.Accept(v) if !ok { - return nod, false + return n, false } - nod.DropColumn = node.(*ColumnName) + n.DropColumn = node.(*ColumnName) } - if nod.Position != nil { - node, ok := nod.Position.Accept(v) + if n.Position != nil { + node, ok := n.Position.Accept(v) if !ok { - return nod, false + return n, false } - nod.Position = node.(*ColumnPosition) + n.Position = node.(*ColumnPosition) } - return v.Leave(nod) + return v.Leave(n) } // AlterTableStmt is a statement to change the structure of a table. @@ -550,25 +550,25 @@ type AlterTableStmt struct { } // Accept implements Node Accept interface. -func (nod *AlterTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *AlterTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*AlterTableStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*AlterTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.Specs { + n.Table = node.(*TableName) + for i, val := range n.Specs { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Specs[i] = node.(*AlterTableSpec) + n.Specs[i] = node.(*AlterTableSpec) } - return v.Leave(nod) + return v.Leave(n) } // TruncateTableStmt is a statement to empty a table completely. @@ -580,16 +580,16 @@ type TruncateTableStmt struct { } // Accept implements Node Accept interface. -func (nod *TruncateTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TruncateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TruncateTableStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*TruncateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - return v.Leave(nod) + n.Table = node.(*TableName) + return v.Leave(n) } diff --git a/ast/dml.go b/ast/dml.go index 0b35a70d0e..e15b5f83bb 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -60,32 +60,32 @@ type Join struct { } // Accept implements Node Accept interface. -func (nod *Join) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Join) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Join) - node, ok := nod.Left.Accept(v) + n = newNod.(*Join) + node, ok := n.Left.Accept(v) if !ok { - return nod, false + return n, false } - nod.Left = node.(ResultSetNode) - if nod.Right != nil { - node, ok = nod.Right.Accept(v) + n.Left = node.(ResultSetNode) + if n.Right != nil { + node, ok = n.Right.Accept(v) if !ok { - return nod, false + return n, false } - nod.Right = node.(ResultSetNode) + n.Right = node.(ResultSetNode) } - if nod.On != nil { - node, ok = nod.On.Accept(v) + if n.On != nil { + node, ok = n.On.Accept(v) if !ok { - return nod, false + return n, false } - nod.On = node.(*OnCondition) + n.On = node.(*OnCondition) } - return v.Leave(nod) + return v.Leave(n) } // TableName represents a table name. @@ -101,13 +101,13 @@ type TableName struct { } // Accept implements Node Accept interface. -func (nod *TableName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TableName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TableName) - return v.Leave(nod) + n = newNod.(*TableName) + return v.Leave(n) } // TableSource represents table source with a name. @@ -123,18 +123,18 @@ type TableSource struct { } // Accept implements Node Accept interface. -func (nod *TableSource) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TableSource) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TableSource) - node, ok := nod.Source.Accept(v) + n = newNod.(*TableSource) + node, ok := n.Source.Accept(v) if !ok { - return nod, false + return n, false } - nod.Source = node.(ResultSetNode) - return v.Leave(nod) + n.Source = node.(ResultSetNode) + return v.Leave(n) } // OnCondition represetns JOIN on condition. @@ -145,28 +145,28 @@ type OnCondition struct { } // Accept implements Node Accept interface. -func (nod *OnCondition) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *OnCondition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*OnCondition) - node, ok := nod.Expr.Accept(v) + n = newNod.(*OnCondition) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // SetResultFields implements ResultSet interface. -func (nod *TableSource) SetResultFields(rfs []*ResultField) { - nod.Source.SetResultFields(rfs) +func (n *TableSource) SetResultFields(rfs []*ResultField) { + n.Source.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (nod *TableSource) GetResultFields() []*ResultField { - return nod.Source.GetResultFields() +func (n *TableSource) GetResultFields() []*ResultField { + return n.Source.GetResultFields() } // SelectLockType is the lock type for SelectStmt. @@ -188,13 +188,13 @@ type WildCardField struct { } // Accept implements Node Accept interface. -func (nod *WildCardField) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *WildCardField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*WildCardField) - return v.Leave(nod) + n = newNod.(*WildCardField) + return v.Leave(n) } // SelectField represents fields in select statement. @@ -214,20 +214,20 @@ type SelectField struct { } // Accept implements Node Accept interface. -func (nod *SelectField) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SelectField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SelectField) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*SelectField) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // FieldList represents field list in select statement. @@ -238,20 +238,20 @@ type FieldList struct { } // Accept implements Node Accept interface. -func (nod *FieldList) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FieldList) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FieldList) - for i, val := range nod.Fields { + n = newNod.(*FieldList) + for i, val := range n.Fields { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Fields[i] = node.(*SelectField) + n.Fields[i] = node.(*SelectField) } - return v.Leave(nod) + return v.Leave(n) } // TableRefsClause represents table references clause in dml statement. @@ -262,18 +262,18 @@ type TableRefsClause struct { } // Accept implements Node Accept interface. -func (nod *TableRefsClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TableRefsClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TableRefsClause) - node, ok := nod.TableRefs.Accept(v) + n = newNod.(*TableRefsClause) + node, ok := n.TableRefs.Accept(v) if !ok { - return nod, false + return n, false } - nod.TableRefs = node.(*Join) - return v.Leave(nod) + n.TableRefs = node.(*Join) + return v.Leave(n) } // ByItem represents an item in order by or group by. @@ -285,18 +285,18 @@ type ByItem struct { } // Accept implements Node Accept interface. -func (nod *ByItem) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ByItem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ByItem) - node, ok := nod.Expr.Accept(v) + n = newNod.(*ByItem) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // GroupByClause represents group by clause. @@ -306,20 +306,20 @@ type GroupByClause struct { } // Accept implements Node Accept interface. -func (nod *GroupByClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *GroupByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*GroupByClause) - for i, val := range nod.Items { + n = newNod.(*GroupByClause) + for i, val := range n.Items { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Items[i] = node.(*ByItem) + n.Items[i] = node.(*ByItem) } - return v.Leave(nod) + return v.Leave(n) } // HavingClause represents having clause. @@ -329,18 +329,18 @@ type HavingClause struct { } // Accept implements Node Accept interface. -func (nod *HavingClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *HavingClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*HavingClause) - node, ok := nod.Expr.Accept(v) + n = newNod.(*HavingClause) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // OrderByClause represents order by clause. @@ -351,20 +351,20 @@ type OrderByClause struct { } // Accept implements Node Accept interface. -func (nod *OrderByClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *OrderByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*OrderByClause) - for i, val := range nod.Items { + n = newNod.(*OrderByClause) + for i, val := range n.Items { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Items[i] = node.(*ByItem) + n.Items[i] = node.(*ByItem) } - return v.Leave(nod) + return v.Leave(n) } // SelectStmt represents the select query node. @@ -394,69 +394,69 @@ type SelectStmt struct { } // Accept implements Node Accept interface. -func (nod *SelectStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SelectStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SelectStmt) + n = newNod.(*SelectStmt) - if nod.From != nil { - node, ok := nod.From.Accept(v) + if n.From != nil { + node, ok := n.From.Accept(v) if !ok { - return nod, false + return n, false } - nod.From = node.(*TableRefsClause) + n.From = node.(*TableRefsClause) } - if nod.Where != nil { - node, ok := nod.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - if nod.Fields != nil { - node, ok := nod.Fields.Accept(v) + if n.Fields != nil { + node, ok := n.Fields.Accept(v) if !ok { - return nod, false + return n, false } - nod.Fields = node.(*FieldList) + n.Fields = node.(*FieldList) } - if nod.GroupBy != nil { - node, ok := nod.GroupBy.Accept(v) + if n.GroupBy != nil { + node, ok := n.GroupBy.Accept(v) if !ok { - return nod, false + return n, false } - nod.GroupBy = node.(*GroupByClause) + n.GroupBy = node.(*GroupByClause) } - if nod.Having != nil { - node, ok := nod.Having.Accept(v) + if n.Having != nil { + node, ok := n.Having.Accept(v) if !ok { - return nod, false + return n, false } - nod.Having = node.(*HavingClause) + n.Having = node.(*HavingClause) } - if nod.OrderBy != nil { - node, ok := nod.OrderBy.Accept(v) + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) if !ok { - return nod, false + return n, false } - nod.OrderBy = node.(*OrderByClause) + n.OrderBy = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok := nod.Limit.Accept(v) + if n.Limit != nil { + node, ok := n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. @@ -468,18 +468,18 @@ type UnionClause struct { } // Accept implements Node Accept interface. -func (nod *UnionClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UnionClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UnionClause) - node, ok := nod.Select.Accept(v) + n = newNod.(*UnionClause) + node, ok := n.Select.Accept(v) if !ok { - return nod, false + return n, false } - nod.Select = node.(*SelectStmt) - return v.Leave(nod) + n.Select = node.(*SelectStmt) + return v.Leave(n) } // UnionStmt represents "union statement" @@ -495,34 +495,34 @@ type UnionStmt struct { } // Accept implements Node Accept interface. -func (nod *UnionStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UnionStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UnionStmt) - for i, val := range nod.Selects { + n = newNod.(*UnionStmt) + for i, val := range n.Selects { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Selects[i] = node.(*SelectStmt) + n.Selects[i] = node.(*SelectStmt) } - if nod.OrderBy != nil { - node, ok := nod.OrderBy.Accept(v) + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) if !ok { - return nod, false + return n, false } - nod.OrderBy = node.(*OrderByClause) + n.OrderBy = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok := nod.Limit.Accept(v) + if n.Limit != nil { + node, ok := n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // Assignment is the expression for assignment, like a = 1. @@ -535,23 +535,23 @@ type Assignment struct { } // Accept implements Node Accept interface. -func (nod *Assignment) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Assignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Assignment) - node, ok := nod.Column.Accept(v) + n = newNod.(*Assignment) + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) - node, ok = nod.Expr.Accept(v) + n.Column = node.(*ColumnName) + node, ok = n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // Priority const values. @@ -579,57 +579,57 @@ type InsertStmt struct { } // Accept implements Node Accept interface. -func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *InsertStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*InsertStmt) + n = newNod.(*InsertStmt) - if nod.Select != nil { - node, ok := nod.Select.Accept(v) + if n.Select != nil { + node, ok := n.Select.Accept(v) if !ok { - return nod, false + return n, false } - nod.Select = node.(ResultSetNode) + n.Select = node.(ResultSetNode) } - node, ok := nod.Table.Accept(v) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableRefsClause) + n.Table = node.(*TableRefsClause) - for i, val := range nod.Columns { + for i, val := range n.Columns { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Columns[i] = node.(*ColumnName) + n.Columns[i] = node.(*ColumnName) } - for i, list := range nod.Lists { + for i, list := range n.Lists { for j, val := range list { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Lists[i][j] = node.(ExprNode) + n.Lists[i][j] = node.(ExprNode) } } - for i, val := range nod.Setlist { + for i, val := range n.Setlist { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Setlist[i] = node.(*Assignment) + n.Setlist[i] = node.(*Assignment) } - for i, val := range nod.OnDuplicate { + for i, val := range n.OnDuplicate { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.OnDuplicate[i] = node.(*Assignment) + n.OnDuplicate[i] = node.(*Assignment) } - return v.Leave(nod) + return v.Leave(n) } // DeleteStmt is a statement to delete rows from table. @@ -652,49 +652,49 @@ type DeleteStmt struct { } // Accept implements Node Accept interface. -func (nod *DeleteStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DeleteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DeleteStmt) + n = newNod.(*DeleteStmt) - node, ok := nod.TableRefs.Accept(v) + node, ok := n.TableRefs.Accept(v) if !ok { - return nod, false + return n, false } - nod.TableRefs = node.(*TableRefsClause) + n.TableRefs = node.(*TableRefsClause) - for i, val := range nod.Tables { + for i, val := range n.Tables { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - if nod.Where != nil { - node, ok = nod.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - if nod.Order != nil { - node, ok = nod.Order.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return nod, false + return n, false } - nod.Order = node.(*OrderByClause) + n.Order = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok = nod.Limit.Accept(v) + if n.Limit != nil { + node, ok = n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // UpdateStmt is a statement to update columns of existing rows in tables with new values. @@ -713,46 +713,46 @@ type UpdateStmt struct { } // Accept implements Node Accept interface. -func (nod *UpdateStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UpdateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UpdateStmt) - node, ok := nod.TableRefs.Accept(v) + n = newNod.(*UpdateStmt) + node, ok := n.TableRefs.Accept(v) if !ok { - return nod, false + return n, false } - nod.TableRefs = node.(*TableRefsClause) - for i, val := range nod.List { + n.TableRefs = node.(*TableRefsClause) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.List[i] = node.(*Assignment) + n.List[i] = node.(*Assignment) } - if nod.Where != nil { - node, ok = nod.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - if nod.Order != nil { - node, ok = nod.Order.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return nod, false + return n, false } - nod.Order = node.(*OrderByClause) + n.Order = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok = nod.Limit.Accept(v) + if n.Limit != nil { + node, ok = n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // Limit is the limit clause. @@ -764,11 +764,11 @@ type Limit struct { } // Accept implements Node Accept interface. -func (nod *Limit) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Limit) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Limit) - return v.Leave(nod) + n = newNod.(*Limit) + return v.Leave(n) } diff --git a/ast/expressions.go b/ast/expressions.go index cddb7f45d5..c3278bd331 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -91,18 +91,18 @@ func NewValueExpr(value interface{}) *ValueExpr { } // IsStatic implements ExprNode interface. -func (nod *ValueExpr) IsStatic() bool { +func (n *ValueExpr) IsStatic() bool { return true } // Accept implements Node interface. -func (nod *ValueExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ValueExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ValueExpr) - return v.Leave(nod) + n = newNod.(*ValueExpr) + return v.Leave(n) } // BetweenExpr is for "between and" or "not between and" expression. @@ -119,37 +119,37 @@ type BetweenExpr struct { } // Accept implements Node interface. -func (nod *BetweenExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *BetweenExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*BetweenExpr) + n = newNod.(*BetweenExpr) - node, ok := nod.Expr.Accept(v) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = nod.Left.Accept(v) + node, ok = n.Left.Accept(v) if !ok { - return nod, false + return n, false } - nod.Left = node.(ExprNode) + n.Left = node.(ExprNode) - node, ok = nod.Right.Accept(v) + node, ok = n.Right.Accept(v) if !ok { - return nod, false + return n, false } - nod.Right = node.(ExprNode) + n.Right = node.(ExprNode) - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *BetweenExpr) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Left.IsStatic() && nod.Right.IsStatic() +func (n *BetweenExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Left.IsStatic() && n.Right.IsStatic() } // BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc. @@ -164,31 +164,31 @@ type BinaryOperationExpr struct { } // Accept implements Node interface. -func (nod *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*BinaryOperationExpr) + n = newNod.(*BinaryOperationExpr) - node, ok := nod.L.Accept(v) + node, ok := n.L.Accept(v) if !ok { - return nod, false + return n, false } - nod.L = node.(ExprNode) + n.L = node.(ExprNode) - node, ok = nod.R.Accept(v) + node, ok = n.R.Accept(v) if !ok { - return nod, false + return n, false } - nod.R = node.(ExprNode) + n.R = node.(ExprNode) - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *BinaryOperationExpr) IsStatic() bool { - return nod.L.IsStatic() && nod.R.IsStatic() +func (n *BinaryOperationExpr) IsStatic() bool { + return n.L.IsStatic() && n.R.IsStatic() } // WhenClause is the when clause in Case expression for "when condition then result". @@ -201,29 +201,29 @@ type WhenClause struct { } // Accept implements Node Accept interface. -func (nod *WhenClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *WhenClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*WhenClause) - node, ok := nod.Expr.Accept(v) + n = newNod.(*WhenClause) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = nod.Result.Accept(v) + node, ok = n.Result.Accept(v) if !ok { - return nod, false + return n, false } - nod.Result = node.(ExprNode) - return v.Leave(nod) + n.Result = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *WhenClause) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Result.IsStatic() +func (n *WhenClause) IsStatic() bool { + return n.Expr.IsStatic() && n.Result.IsStatic() } // CaseExpr is the case expression. @@ -238,47 +238,47 @@ type CaseExpr struct { } // Accept implements Node Accept interface. -func (nod *CaseExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CaseExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CaseExpr) - if nod.Value != nil { - node, ok := nod.Value.Accept(v) + n = newNod.(*CaseExpr) + if n.Value != nil { + node, ok := n.Value.Accept(v) if !ok { - return nod, false + return n, false } - nod.Value = node.(ExprNode) + n.Value = node.(ExprNode) } - for i, val := range nod.WhenClauses { + for i, val := range n.WhenClauses { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.WhenClauses[i] = node.(*WhenClause) + n.WhenClauses[i] = node.(*WhenClause) } - if nod.ElseClause != nil { - node, ok := nod.ElseClause.Accept(v) + if n.ElseClause != nil { + node, ok := n.ElseClause.Accept(v) if !ok { - return nod, false + return n, false } - nod.ElseClause = node.(ExprNode) + n.ElseClause = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *CaseExpr) IsStatic() bool { - if nod.Value != nil && !nod.Value.IsStatic() { +func (n *CaseExpr) IsStatic() bool { + if n.Value != nil && !n.Value.IsStatic() { return false } - for _, w := range nod.WhenClauses { + for _, w := range n.WhenClauses { if !w.IsStatic() { return false } } - if nod.ElseClause != nil && !nod.ElseClause.IsStatic() { + if n.ElseClause != nil && !n.ElseClause.IsStatic() { return false } return true @@ -292,28 +292,28 @@ type SubqueryExpr struct { } // Accept implements Node Accept interface. -func (nod *SubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SubqueryExpr) - node, ok := nod.Query.Accept(v) + n = newNod.(*SubqueryExpr) + node, ok := n.Query.Accept(v) if !ok { - return nod, false + return n, false } - nod.Query = node.(ResultSetNode) - return v.Leave(nod) + n.Query = node.(ResultSetNode) + return v.Leave(n) } // SetResultFields implements ResultSet interface. -func (nod *SubqueryExpr) SetResultFields(rfs []*ResultField) { - nod.Query.SetResultFields(rfs) +func (n *SubqueryExpr) SetResultFields(rfs []*ResultField) { + n.Query.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (nod *SubqueryExpr) GetResultFields() []*ResultField { - return nod.Query.GetResultFields() +func (n *SubqueryExpr) GetResultFields() []*ResultField { + return n.Query.GetResultFields() } // CompareSubqueryExpr is the expression for "expr cmp (select ...)". @@ -333,23 +333,23 @@ type CompareSubqueryExpr struct { } // Accept implements Node Accept interface. -func (nod *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CompareSubqueryExpr) - node, ok := nod.L.Accept(v) + n = newNod.(*CompareSubqueryExpr) + node, ok := n.L.Accept(v) if !ok { - return nod, false + return n, false } - nod.L = node.(ExprNode) - node, ok = nod.R.Accept(v) + n.L = node.(ExprNode) + node, ok = n.R.Accept(v) if !ok { - return nod, false + return n, false } - nod.R = node.(*SubqueryExpr) - return v.Leave(nod) + n.R = node.(*SubqueryExpr) + return v.Leave(n) } // ColumnName represents column name. @@ -365,13 +365,13 @@ type ColumnName struct { } // Accept implements Node Accept interface. -func (nod *ColumnName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnName) - return v.Leave(nod) + n = newNod.(*ColumnName) + return v.Leave(n) } // ColumnNameExpr represents a column name expression. @@ -383,18 +383,18 @@ type ColumnNameExpr struct { } // Accept implements Node Accept interface. -func (nod *ColumnNameExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnNameExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnNameExpr) - node, ok := nod.Name.Accept(v) + n = newNod.(*ColumnNameExpr) + node, ok := n.Name.Accept(v) if !ok { - return nod, false + return n, false } - nod.Name = node.(*ColumnName) - return v.Leave(nod) + n.Name = node.(*ColumnName) + return v.Leave(n) } // DefaultExpr is the default expression using default value for a column. @@ -405,20 +405,20 @@ type DefaultExpr struct { } // Accept implements Node Accept interface. -func (nod *DefaultExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DefaultExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DefaultExpr) - if nod.Name != nil { - node, ok := nod.Name.Accept(v) + n = newNod.(*DefaultExpr) + if n.Name != nil { + node, ok := n.Name.Accept(v) if !ok { - return nod, false + return n, false } - nod.Name = node.(*ColumnName) + n.Name = node.(*ColumnName) } - return v.Leave(nod) + return v.Leave(n) } // IdentifierExpr represents an identifier expression. @@ -429,13 +429,13 @@ type IdentifierExpr struct { } // Accept implements Node Accept interface. -func (nod *IdentifierExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IdentifierExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IdentifierExpr) - return v.Leave(nod) + n = newNod.(*IdentifierExpr) + return v.Leave(n) } // ExistsSubqueryExpr is the expression for "exists (select ...)". @@ -447,18 +447,18 @@ type ExistsSubqueryExpr struct { } // Accept implements Node Accept interface. -func (nod *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ExistsSubqueryExpr) - node, ok := nod.Sel.Accept(v) + n = newNod.(*ExistsSubqueryExpr) + node, ok := n.Sel.Accept(v) if !ok { - return nod, false + return n, false } - nod.Sel = node.(*SubqueryExpr) - return v.Leave(nod) + n.Sel = node.(*SubqueryExpr) + return v.Leave(n) } // PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)". @@ -475,32 +475,32 @@ type PatternInExpr struct { } // Accept implements Node Accept interface. -func (nod *PatternInExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PatternInExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PatternInExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*PatternInExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - for i, val := range nod.List { + n.Expr = node.(ExprNode) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.List[i] = node.(ExprNode) + n.List[i] = node.(ExprNode) } - if nod.Sel != nil { - node, ok = nod.Sel.Accept(v) + if n.Sel != nil { + node, ok = n.Sel.Accept(v) if !ok { - return nod, false + return n, false } - nod.Sel = node.(*SubqueryExpr) + n.Sel = node.(*SubqueryExpr) } - return v.Leave(nod) + return v.Leave(n) } // IsNullExpr is the expression for null check. @@ -513,23 +513,23 @@ type IsNullExpr struct { } // Accept implements Node Accept interface. -func (nod *IsNullExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IsNullExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IsNullExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*IsNullExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *IsNullExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *IsNullExpr) IsStatic() bool { + return n.Expr.IsStatic() } // IsTruthExpr is the expression for true/false check. @@ -544,23 +544,23 @@ type IsTruthExpr struct { } // Accept implements Node Accept interface. -func (nod *IsTruthExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IsTruthExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IsTruthExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*IsTruthExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *IsTruthExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *IsTruthExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PatternLikeExpr is the expression for like operator, e.g, expr like "%123%" @@ -577,32 +577,32 @@ type PatternLikeExpr struct { } // Accept implements Node Accept interface. -func (nod *PatternLikeExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PatternLikeExpr) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*PatternLikeExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - if nod.Pattern != nil { - node, ok := nod.Pattern.Accept(v) + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pattern = node.(ExprNode) + n.Pattern = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *PatternLikeExpr) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Pattern.IsStatic() +func (n *PatternLikeExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // ParamMarkerExpr expresion holds a place for another expression. @@ -613,13 +613,13 @@ type ParamMarkerExpr struct { } // Accept implements Node Accept interface. -func (nod *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ParamMarkerExpr) - return v.Leave(nod) + n = newNod.(*ParamMarkerExpr) + return v.Leave(n) } // ParenthesesExpr is the parentheses expression. @@ -630,25 +630,25 @@ type ParenthesesExpr struct { } // Accept implements Node Accept interface. -func (nod *ParenthesesExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ParenthesesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ParenthesesExpr) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*ParenthesesExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *ParenthesesExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *ParenthesesExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PositionExpr is the expression for order by and group by position. @@ -663,18 +663,18 @@ type PositionExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (nod *PositionExpr) IsStatic() bool { +func (n *PositionExpr) IsStatic() bool { return true } // Accept implements Node Accept interface. -func (nod *PositionExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PositionExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PositionExpr) - return v.Leave(nod) + n = newNod.(*PositionExpr) + return v.Leave(n) } // PatternRegexpExpr is the pattern expression for pattern match. @@ -689,28 +689,28 @@ type PatternRegexpExpr struct { } // Accept implements Node Accept interface. -func (nod *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PatternRegexpExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*PatternRegexpExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - node, ok = nod.Pattern.Accept(v) + n.Expr = node.(ExprNode) + node, ok = n.Pattern.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pattern = node.(ExprNode) - return v.Leave(nod) + n.Pattern = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *PatternRegexpExpr) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Pattern.IsStatic() +func (n *PatternRegexpExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // RowExpr is the expression for row constructor. @@ -722,25 +722,25 @@ type RowExpr struct { } // Accept implements Node Accept interface. -func (nod *RowExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *RowExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*RowExpr) - for i, val := range nod.Values { + n = newNod.(*RowExpr) + for i, val := range n.Values { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Values[i] = node.(ExprNode) + n.Values[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *RowExpr) IsStatic() bool { - for _, v := range nod.Values { +func (n *RowExpr) IsStatic() bool { + for _, v := range n.Values { if !v.IsStatic() { return false } @@ -758,23 +758,23 @@ type UnaryOperationExpr struct { } // Accept implements Node Accept interface. -func (nod *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UnaryOperationExpr) - node, ok := nod.V.Accept(v) + n = newNod.(*UnaryOperationExpr) + node, ok := n.V.Accept(v) if !ok { - return nod, false + return n, false } - nod.V = node.(ExprNode) - return v.Leave(nod) + n.V = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *UnaryOperationExpr) IsStatic() bool { - return nod.V.IsStatic() +func (n *UnaryOperationExpr) IsStatic() bool { + return n.V.IsStatic() } // ValuesExpr is the expression used in INSERT VALUES @@ -785,18 +785,18 @@ type ValuesExpr struct { } // Accept implements Node Accept interface. -func (nod *ValuesExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ValuesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ValuesExpr) - node, ok := nod.Column.Accept(v) + n = newNod.(*ValuesExpr) + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) - return v.Leave(nod) + n.Column = node.(*ColumnName) + return v.Leave(n) } // VariableExpr is the expression for variable. @@ -811,11 +811,11 @@ type VariableExpr struct { } // Accept implements Node Accept interface. -func (nod *VariableExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *VariableExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*VariableExpr) - return v.Leave(nod) + n = newNod.(*VariableExpr) + return v.Leave(n) } diff --git a/ast/functions.go b/ast/functions.go index 8d688f9eee..7faab4c6e1 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -45,30 +45,30 @@ type FuncCallExpr struct { } // Accept implements Node interface. -func (nod *FuncCallExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncCallExpr) - for i, val := range nod.Args { + n = newNod.(*FuncCallExpr) + for i, val := range n.Args { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Args[i] = node.(ExprNode) + n.Args[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(nod.FnName)] +func (n *FuncCallExpr) IsStatic() bool { + v := builtin.Funcs[strings.ToLower(n.FnName)] if v.F == nil || !v.IsStatic { return false } - for _, v := range nod.Args { + for _, v := range n.Args { if !v.IsStatic() { return false } @@ -86,23 +86,23 @@ type FuncExtractExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncExtractExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncExtractExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncExtractExpr) - node, ok := nod.Date.Accept(v) + n = newNod.(*FuncExtractExpr) + node, ok := n.Date.Accept(v) if !ok { - return nod, false + return n, false } - nod.Date = node.(ExprNode) - return v.Leave(nod) + n.Date = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncExtractExpr) IsStatic() bool { - return nod.Date.IsStatic() +func (n *FuncExtractExpr) IsStatic() bool { + return n.Date.IsStatic() } // FuncConvertExpr provides a way to convert data between different character sets. @@ -116,23 +116,23 @@ type FuncConvertExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncConvertExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *FuncConvertExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (nod *FuncConvertExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncConvertExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncConvertExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*FuncConvertExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // CastFunctionType is the type for cast function. @@ -158,23 +158,23 @@ type FuncCastExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncCastExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *FuncCastExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (nod *FuncCastExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncCastExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*FuncCastExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // FuncSubstringExpr returns the substring as specified. @@ -188,35 +188,35 @@ type FuncSubstringExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncSubstringExpr) - node, ok := nod.StrExpr.Accept(v) + n = newNod.(*FuncSubstringExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return nod, false + return n, false } - nod.StrExpr = node.(ExprNode) - node, ok = nod.Pos.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pos = node.(ExprNode) - if nod.Len != nil { - node, ok = nod.Len.Accept(v) + n.Pos = node.(ExprNode) + if n.Len != nil { + node, ok = n.Len.Accept(v) if !ok { - return nod, false + return n, false } - nod.Len = node.(ExprNode) + n.Len = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncSubstringExpr) IsStatic() bool { - return nod.StrExpr.IsStatic() && nod.Pos.IsStatic() && nod.Len.IsStatic() +func (n *FuncSubstringExpr) IsStatic() bool { + return n.StrExpr.IsStatic() && n.Pos.IsStatic() && n.Len.IsStatic() } // FuncSubstringIndexExpr returns the substring as specified. @@ -230,28 +230,28 @@ type FuncSubstringIndexExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncSubstringIndexExpr) - node, ok := nod.StrExpr.Accept(v) + n = newNod.(*FuncSubstringIndexExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return nod, false + return n, false } - nod.StrExpr = node.(ExprNode) - node, ok = nod.Delim.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Delim.Accept(v) if !ok { - return nod, false + return n, false } - nod.Delim = node.(ExprNode) - node, ok = nod.Count.Accept(v) + n.Delim = node.(ExprNode) + node, ok = n.Count.Accept(v) if !ok { - return nod, false + return n, false } - nod.Count = node.(ExprNode) - return v.Leave(nod) + n.Count = node.(ExprNode) + return v.Leave(n) } // FuncLocateExpr returns the position of the first occurrence of substring. @@ -265,28 +265,28 @@ type FuncLocateExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncLocateExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncLocateExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncLocateExpr) - node, ok := nod.Str.Accept(v) + n = newNod.(*FuncLocateExpr) + node, ok := n.Str.Accept(v) if !ok { - return nod, false + return n, false } - nod.Str = node.(ExprNode) - node, ok = nod.SubStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.SubStr.Accept(v) if !ok { - return nod, false + return n, false } - nod.SubStr = node.(ExprNode) - node, ok = nod.Pos.Accept(v) + n.SubStr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pos = node.(ExprNode) - return v.Leave(nod) + n.Pos = node.(ExprNode) + return v.Leave(n) } // TrimDirectionType is the type for trim direction. @@ -314,28 +314,28 @@ type FuncTrimExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncTrimExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncTrimExpr) - node, ok := nod.Str.Accept(v) + n = newNod.(*FuncTrimExpr) + node, ok := n.Str.Accept(v) if !ok { - return nod, false + return n, false } - nod.Str = node.(ExprNode) - node, ok = nod.RemStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.RemStr.Accept(v) if !ok { - return nod, false + return n, false } - nod.RemStr = node.(ExprNode) - return v.Leave(nod) + n.RemStr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncTrimExpr) IsStatic() bool { - return nod.Str.IsStatic() && nod.RemStr.IsStatic() +func (n *FuncTrimExpr) IsStatic() bool { + return n.Str.IsStatic() && n.RemStr.IsStatic() } // DateArithType is type for DateArith option. @@ -361,27 +361,27 @@ type FuncDateArithExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncDateArithExpr) - if nod.Date != nil { - node, ok := nod.Date.Accept(v) + n = newNod.(*FuncDateArithExpr) + if n.Date != nil { + node, ok := n.Date.Accept(v) if !ok { - return nod, false + return n, false } - nod.Date = node.(ExprNode) + n.Date = node.(ExprNode) } - if nod.Interval != nil { - node, ok := nod.Date.Accept(v) + if n.Interval != nil { + node, ok := n.Date.Accept(v) if !ok { - return nod, false + return n, false } - nod.Date = node.(ExprNode) + n.Date = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // AggregateFuncExpr represents aggregate function expression. @@ -398,18 +398,18 @@ type AggregateFuncExpr struct { } // Accept implements Node Accept interface. -func (nod *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*AggregateFuncExpr) - for i, val := range nod.Args { + n = newNod.(*AggregateFuncExpr) + for i, val := range n.Args { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Args[i] = node.(ExprNode) + n.Args[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } diff --git a/ast/misc.go b/ast/misc.go index b0f9e845fc..4f287b4680 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -61,18 +61,18 @@ type ExplainStmt struct { } // Accept implements Node Accept interface. -func (nod *ExplainStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ExplainStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ExplainStmt) - node, ok := nod.Stmt.Accept(v) + n = newNod.(*ExplainStmt) + node, ok := n.Stmt.Accept(v) if !ok { - return nod, false + return n, false } - nod.Stmt = node.(DMLNode) - return v.Leave(nod) + n.Stmt = node.(DMLNode) + return v.Leave(n) } // PrepareStmt is a statement to prepares a SQL statement which contains placeholders, @@ -89,18 +89,18 @@ type PrepareStmt struct { } // Accept implements Node Accept interface. -func (nod *PrepareStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PrepareStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PrepareStmt) - node, ok := nod.SQLVar.Accept(v) + n = newNod.(*PrepareStmt) + node, ok := n.SQLVar.Accept(v) if !ok { - return nod, false + return n, false } - nod.SQLVar = node.(*VariableExpr) - return v.Leave(nod) + n.SQLVar = node.(*VariableExpr) + return v.Leave(n) } // DeallocateStmt is a statement to release PreparedStmt. @@ -113,13 +113,13 @@ type DeallocateStmt struct { } // Accept implements Node Accept interface. -func (nod *DeallocateStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DeallocateStmt) - return v.Leave(nod) + n = newNod.(*DeallocateStmt) + return v.Leave(n) } // ExecuteStmt is a statement to execute PreparedStmt. @@ -133,20 +133,20 @@ type ExecuteStmt struct { } // Accept implements Node Accept interface. -func (nod *ExecuteStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ExecuteStmt) - for i, val := range nod.UsingVars { + n = newNod.(*ExecuteStmt) + for i, val := range n.UsingVars { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.UsingVars[i] = node.(ExprNode) + n.UsingVars[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // ShowStmtType is the type for SHOW statement. @@ -187,41 +187,41 @@ type ShowStmt struct { } // Accept implements Node Accept interface. -func (nod *ShowStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ShowStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ShowStmt) - if nod.Table != nil { - node, ok := nod.Table.Accept(v) + n = newNod.(*ShowStmt) + if n.Table != nil { + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) + n.Table = node.(*TableName) } - if nod.Column != nil { - node, ok := nod.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) + n.Column = node.(*ColumnName) } - if nod.Pattern != nil { - node, ok := nod.Pattern.Accept(v) + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pattern = node.(*PatternLikeExpr) + n.Pattern = node.(*PatternLikeExpr) } - if nod.Where != nil { - node, ok := nod.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // BeginStmt is a statement to start a new transaction. @@ -231,13 +231,13 @@ type BeginStmt struct { } // Accept implements Node Accept interface. -func (nod *BeginStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *BeginStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*BeginStmt) - return v.Leave(nod) + n = newNod.(*BeginStmt) + return v.Leave(n) } // CommitStmt is a statement to commit the current transaction. @@ -247,13 +247,13 @@ type CommitStmt struct { } // Accept implements Node Accept interface. -func (nod *CommitStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CommitStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CommitStmt) - return v.Leave(nod) + n = newNod.(*CommitStmt) + return v.Leave(n) } // RollbackStmt is a statement to roll back the current transaction. @@ -263,13 +263,13 @@ type RollbackStmt struct { } // Accept implements Node Accept interface. -func (nod *RollbackStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *RollbackStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*RollbackStmt) - return v.Leave(nod) + n = newNod.(*RollbackStmt) + return v.Leave(n) } // UseStmt is a statement to use the DBName database as the current database. @@ -281,13 +281,13 @@ type UseStmt struct { } // Accept implements Node Accept interface. -func (nod *UseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UseStmt) - return v.Leave(nod) + n = newNod.(*UseStmt) + return v.Leave(n) } // VariableAssignment is a variable assignment struct. @@ -300,18 +300,18 @@ type VariableAssignment struct { } // Accept implements Node interface. -func (nod *VariableAssignment) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *VariableAssignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*VariableAssignment) - node, ok := nod.Value.Accept(v) + n = newNod.(*VariableAssignment) + node, ok := n.Value.Accept(v) if !ok { - return nod, false + return n, false } - nod.Value = node.(ExprNode) - return v.Leave(nod) + n.Value = node.(ExprNode) + return v.Leave(n) } // SetStmt is the statement to set variables. @@ -322,20 +322,20 @@ type SetStmt struct { } // Accept implements Node Accept interface. -func (nod *SetStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SetStmt) - for i, val := range nod.Variables { + n = newNod.(*SetStmt) + for i, val := range n.Variables { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Variables[i] = node.(*VariableAssignment) + n.Variables[i] = node.(*VariableAssignment) } - return v.Leave(nod) + return v.Leave(n) } // SetCharsetStmt is a statement to assign values to character and collation variables. @@ -348,13 +348,13 @@ type SetCharsetStmt struct { } // Accept implements Node Accept interface. -func (nod *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SetCharsetStmt) - return v.Leave(nod) + n = newNod.(*SetCharsetStmt) + return v.Leave(n) } // SetPwdStmt is a statement to assign a password to user account. @@ -367,13 +367,13 @@ type SetPwdStmt struct { } // Accept implements Node Accept interface. -func (nod *SetPwdStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SetPwdStmt) - return v.Leave(nod) + n = newNod.(*SetPwdStmt) + return v.Leave(n) } // UserSpec is used for parsing create user statement. @@ -392,13 +392,13 @@ type CreateUserStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateUserStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateUserStmt) - return v.Leave(nod) + n = newNod.(*CreateUserStmt) + return v.Leave(n) } // DoStmt is the struct for DO statement. @@ -409,20 +409,20 @@ type DoStmt struct { } // Accept implements Node Accept interface. -func (nod *DoStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DoStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DoStmt) - for i, val := range nod.Exprs { + n = newNod.(*DoStmt) + for i, val := range n.Exprs { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Exprs[i] = node.(ExprNode) + n.Exprs[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // PrivElem is the privilege type and optional column list. @@ -433,20 +433,20 @@ type PrivElem struct { } // Accept implements Node Accept interface. -func (nod *PrivElem) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PrivElem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PrivElem) - for i, val := range nod.Cols { + n = newNod.(*PrivElem) + for i, val := range n.Cols { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Cols[i] = node.(*ColumnName) + n.Cols[i] = node.(*ColumnName) } - return v.Leave(nod) + return v.Leave(n) } // ObjectTypeType is the type for object type. @@ -491,18 +491,18 @@ type GrantStmt struct { } // Accept implements Node Accept interface. -func (nod *GrantStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *GrantStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*GrantStmt) - for i, val := range nod.Privs { + n = newNod.(*GrantStmt) + for i, val := range n.Privs { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Privs[i] = node.(*PrivElem) + n.Privs[i] = node.(*PrivElem) } - return v.Leave(nod) + return v.Leave(n) }