From 18d8ea112c07154b7f5e377d7fdd7036eef91970 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Sun, 25 Oct 2015 18:36:52 +0800 Subject: [PATCH 01/27] ast, optimiser: change ast visitor API and implement cloner and binder --- ast/ast.go | 3 +- ast/cloner.go | 168 ++++++++++ ast/cloner_test.go | 40 +++ ast/ddl.go | 302 +++++++++-------- ast/dml.go | 619 ++++++++++++++++++++++------------- ast/expressions.go | 484 +++++++++++++++------------ ast/functions.go | 196 ++++++----- ast/misc.go | 208 +++++++----- ast/parser/parser.y | 178 +++++----- ast/parser/parser_test.go | 4 +- ast/parser/yy_parser.go | 19 ++ optimizer/infobinder.go | 441 +++++++++++++++++++++++++ optimizer/infobinder_test.go | 69 ++++ 13 files changed, 1882 insertions(+), 849 deletions(-) create mode 100644 ast/cloner.go create mode 100644 ast/cloner_test.go create mode 100644 ast/parser/yy_parser.go create mode 100644 optimizer/infobinder.go create mode 100644 optimizer/infobinder_test.go diff --git a/ast/ast.go b/ast/ast.go index 0e9ab19893..dabc4fa912 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -95,8 +95,7 @@ type Visitor interface { // VisitEnter is called before children nodes is visited. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. - // ok returns false to stop visiting. - Enter(n Node) (skipChildren bool, ok bool) + Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) diff --git a/ast/cloner.go b/ast/cloner.go new file mode 100644 index 0000000000..f08ef09476 --- /dev/null +++ b/ast/cloner.go @@ -0,0 +1,168 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +// Cloner is a ast visitor that clones a node. +type Cloner struct { +} + +// Enter implements Visitor Enter interface. +func (c *Cloner) Enter(node Node) (Node, bool) { + return cloneStruct(node), false +} + +// Leave implements Visitor Leave interface. +func (c *Cloner) Leave(in Node) (out Node, ok bool) { + return in, true +} + +// cloneStruct clone a node's struct value, if the struct has slice +// the cloned value should make a new slice and copy old slice to new slice. +func cloneStruct(in Node) (out Node) { + switch v := in.(type) { + case *ValueExpr: + nv := *v + out = &nv + case *BetweenExpr: + nv := *v + out = &nv + case *BinaryOperationExpr: + nv := *v + out = &nv + case *WhenClause: + nv := *v + out = &nv + case *CaseExpr: + nv := *v + nv.WhenClauses = make([]*WhenClause, len(v.WhenClauses)) + copy(nv.WhenClauses, v.WhenClauses) + out = &nv + case *SubqueryExpr: + nv := *v + out = &nv + case *CompareSubqueryExpr: + nv := *v + out = &nv + case *ColumnName: + nv := *v + out = &nv + case *DefaultExpr: + nv := *v + out = &nv + case *IdentifierExpr: + nv := *v + out = &nv + case *ExistsSubqueryExpr: + nv := *v + out = &nv + case *PatternInExpr: + nv := *v + nv.List = make([]ExprNode, len(v.List)) + copy(nv.List, v.List) + out = &nv + case *IsNullExpr: + nv := *v + out = &nv + case *IsTruthExpr: + nv := *v + out = &nv + case *PatternLikeExpr: + nv := *v + out = &nv + case *ParamMarkerExpr: + nv := *v + out = &nv + case *ParenthesesExpr: + nv := *v + out = &nv + case *PositionExpr: + nv := *v + out = &nv + case *PatternRegexpExpr: + nv := *v + out = &nv + case *RowExpr: + nv := *v + nv.Values = make([]ExprNode, len(v.Values)) + copy(nv.Values, v.Values) + out = &nv + case *UnaryOperationExpr: + nv := *v + out = &nv + case *ValuesExpr: + nv := *v + out = &nv + case *VariableExpr: + nv := *v + out = &nv + case *Join: + nv := *v + out = &nv + case *TableName: + nv := *v + out = &nv + case *TableSource: + nv := *v + out = &nv + case *OnCondition: + nv := *v + out = &nv + case *WildCardField: + nv := *v + out = &nv + case *SelectField: + nv := *v + out = &nv + case *FieldList: + nv := *v + nv.Fields = make([]*SelectField, len(v.Fields)) + copy(nv.Fields, v.Fields) + out = &nv + case *TableRefsClause: + nv := *v + out = &nv + case *ByItem: + nv := *v + out = &nv + case *GroupByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *HavingClause: + nv := *v + out = &nv + case *OrderByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *SelectStmt: + nv := *v + out = &nv + case *UnionClause: + nv := *v + out = &nv + case *UnionStmt: + nv := *v + nv.Selects = make([]*SelectStmt, len(v.Selects)) + copy(nv.Selects, v.Selects) + out = &nv + default: + // We currently only handle expression and select statement. + // Will add more when we need to. + panic("unknown ast Node type") + } + return +} diff --git a/ast/cloner_test.go b/ast/cloner_test.go new file mode 100644 index 0000000000..186ee824a7 --- /dev/null +++ b/ast/cloner_test.go @@ -0,0 +1,40 @@ +package ast + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/parser/opcode" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testClonerSuite{}) + +type testClonerSuite struct { +} + +func (ts *testClonerSuite) TestCloner(c *C) { + cloner := &Cloner{} + + a := &UnaryOperationExpr{ + Op: opcode.Not, + V: &UnaryOperationExpr{V: &ValueExpr{Val: true}}, + } + + b, ok := a.Accept(cloner) + c.Assert(ok, IsTrue) + a1 := a.V + b1 := b.(*UnaryOperationExpr).V + c.Assert(a1, Not(Equals), b1) + a2 := a1.(*UnaryOperationExpr).V + b2 := b1.(*UnaryOperationExpr).V + c.Assert(a2, Not(Equals), b2) + a3 := a2.(*ValueExpr) + b3 := b2.(*ValueExpr) + c.Assert(a3, Not(Equals), b3) + c.Assert(a3.Val, Equals, true) + c.Assert(b3.Val, Equals, true) +} diff --git a/ast/ddl.go b/ast/ddl.go index 0e43b69552..37a5d99c88 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -70,11 +70,13 @@ type CreateDatabaseStmt struct { } // Accept implements Node Accept interface. -func (cd *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (nod *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cd) + nod = newNod.(*CreateDatabaseStmt) + return v.Leave(nod) } // DropDatabaseStmt is a statement to drop a database and all tables in the database. @@ -87,11 +89,13 @@ type DropDatabaseStmt struct { } // Accept implements Node Accept interface. -func (dd *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dd); skipChildren { - return dd, ok +func (nod *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(dd) + nod = newNod.(*DropDatabaseStmt) + return v.Leave(nod) } // IndexColName is used for parsing index column name from SQL. @@ -103,16 +107,18 @@ type IndexColName struct { } // Accept implements Node Accept interface. -func (ic *IndexColName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ic); skipChildren { - return ic, ok +func (nod *IndexColName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ic.Column.Accept(v) + nod = newNod.(*IndexColName) + node, ok := nod.Column.Accept(v) if !ok { - return ic, false + return nod, false } - ic.Column = node.(*ColumnName) - return v.Leave(ic) + nod.Column = node.(*ColumnName) + return v.Leave(nod) } // ReferenceDef is used for parsing foreign key reference option from SQL. @@ -125,23 +131,25 @@ type ReferenceDef struct { } // Accept implements Node Accept interface. -func (rd *ReferenceDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rd); skipChildren { - return rd, ok +func (nod *ReferenceDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := rd.Table.Accept(v) + nod = newNod.(*ReferenceDef) + node, ok := nod.Table.Accept(v) if !ok { - return rd, false + return nod, false } - rd.Table = node.(*TableName) - for i, val := range rd.IndexColNames { + nod.Table = node.(*TableName) + for i, val := range nod.IndexColNames { node, ok = val.Accept(v) if !ok { - return rd, false + return nod, false } - rd.IndexColNames[i] = node.(*IndexColName) + nod.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(rd) + return v.Leave(nod) } // ColumnOptionType is the type for ColumnOption. @@ -175,18 +183,20 @@ type ColumnOption struct { } // Accept implements Node Accept interface. -func (co *ColumnOption) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(co); skipChildren { - return co, ok +func (nod *ColumnOption) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if co.Expr != nil { - node, ok := co.Expr.Accept(v) + nod = newNod.(*ColumnOption) + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) if !ok { - return co, false + return nod, false } - co.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) } - return v.Leave(co) + return v.Leave(nod) } // ConstraintType is the type for Constraint. @@ -220,25 +230,27 @@ type Constraint struct { } // Accept implements Node Accept interface. -func (tc *Constraint) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tc); skipChildren { - return tc, ok +func (nod *Constraint) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range tc.Keys { + nod = newNod.(*Constraint) + for i, val := range nod.Keys { node, ok := val.Accept(v) if !ok { - return tc, false + return nod, false } - tc.Keys[i] = node.(*IndexColName) + nod.Keys[i] = node.(*IndexColName) } - if tc.Refer != nil { - node, ok := tc.Refer.Accept(v) + if nod.Refer != nil { + node, ok := nod.Refer.Accept(v) if !ok { - return tc, false + return nod, false } - tc.Refer = node.(*ReferenceDef) + nod.Refer = node.(*ReferenceDef) } - return v.Leave(tc) + return v.Leave(nod) } // ColumnDef is used for parsing column definition from SQL. @@ -251,23 +263,25 @@ type ColumnDef struct { } // Accept implements Node Accept interface. -func (cd *ColumnDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (nod *ColumnDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := cd.Name.Accept(v) + nod = newNod.(*ColumnDef) + node, ok := nod.Name.Accept(v) if !ok { - return cd, false + return nod, false } - cd.Name = node.(*ColumnName) - for i, val := range cd.Options { + nod.Name = node.(*ColumnName) + for i, val := range nod.Options { node, ok := val.Accept(v) if !ok { - return cd, false + return nod, false } - cd.Options[i] = node.(*ColumnOption) + nod.Options[i] = node.(*ColumnOption) } - return v.Leave(cd) + return v.Leave(nod) } // CreateTableStmt is a statement to create a table. @@ -283,30 +297,32 @@ type CreateTableStmt struct { } // Accept implements Node Accept interface. -func (ct *CreateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ct); skipChildren { - return ct, ok +func (nod *CreateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ct.Table.Accept(v) + nod = newNod.(*CreateTableStmt) + node, ok := nod.Table.Accept(v) if !ok { - return ct, false + return nod, false } - ct.Table = node.(*TableName) - for i, val := range ct.Cols { + nod.Table = node.(*TableName) + for i, val := range nod.Cols { node, ok = val.Accept(v) if !ok { - return ct, false + return nod, false } - ct.Cols[i] = node.(*ColumnDef) + nod.Cols[i] = node.(*ColumnDef) } - for i, val := range ct.Constraints { + for i, val := range nod.Constraints { node, ok = val.Accept(v) if !ok { - return ct, false + return nod, false } - ct.Constraints[i] = node.(*Constraint) + nod.Constraints[i] = node.(*Constraint) } - return v.Leave(ct) + return v.Leave(nod) } // DropTableStmt is a statement to drop one or more tables. @@ -319,18 +335,20 @@ type DropTableStmt struct { } // Accept implements Node Accept interface. -func (dt *DropTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dt); skipChildren { - return dt, ok +func (nod *DropTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range dt.Tables { + nod = newNod.(*DropTableStmt) + for i, val := range nod.Tables { node, ok := val.Accept(v) if !ok { - return dt, false + return nod, false } - dt.Tables[i] = node.(*TableName) + nod.Tables[i] = node.(*TableName) } - return v.Leave(dt) + return v.Leave(nod) } // CreateIndexStmt is a statement to create an index. @@ -345,23 +363,25 @@ type CreateIndexStmt struct { } // Accept implements Node Accept interface. -func (ci *CreateIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ci); skipChildren { - return ci, ok +func (nod *CreateIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ci.Table.Accept(v) + nod = newNod.(*CreateIndexStmt) + node, ok := nod.Table.Accept(v) if !ok { - return ci, false + return nod, false } - ci.Table = node.(*TableName) - for i, val := range ci.IndexColNames { + nod.Table = node.(*TableName) + for i, val := range nod.IndexColNames { node, ok = val.Accept(v) if !ok { - return ci, false + return nod, false } - ci.IndexColNames[i] = node.(*IndexColName) + nod.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(ci) + return v.Leave(nod) } // DropIndexStmt is a statement to drop the index. @@ -375,16 +395,18 @@ type DropIndexStmt struct { } // Accept implements Node Accept interface. -func (di *DropIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(di); skipChildren { - return di, ok +func (nod *DropIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := di.Table.Accept(v) + nod = newNod.(*DropIndexStmt) + node, ok := nod.Table.Accept(v) if !ok { - return di, false + return nod, false } - di.Table = node.(*TableName) - return v.Leave(di) + nod.Table = node.(*TableName) + return v.Leave(nod) } // TableOptionType is the type for TableOption @@ -435,18 +457,20 @@ type ColumnPosition struct { } // Accept implements Node Accept interface. -func (cp *ColumnPosition) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cp); skipChildren { - return cp, ok +func (nod *ColumnPosition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if cp.RelativeColumn != nil { - node, ok := cp.RelativeColumn.Accept(v) + nod = newNod.(*ColumnPosition) + if nod.RelativeColumn != nil { + node, ok := nod.RelativeColumn.Accept(v) if !ok { - return cp, false + return nod, false } - cp.RelativeColumn = node.(*ColumnName) + nod.RelativeColumn = node.(*ColumnName) } - return v.Leave(cp) + return v.Leave(nod) } // AlterTableType is the type for AlterTableSpec. @@ -479,39 +503,41 @@ type AlterTableSpec struct { } // Accept implements Node Accept interface. -func (as *AlterTableSpec) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (nod *AlterTableSpec) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if as.Constraint != nil { - node, ok := as.Constraint.Accept(v) + nod = newNod.(*AlterTableSpec) + if nod.Constraint != nil { + node, ok := nod.Constraint.Accept(v) if !ok { - return as, false + return nod, false } - as.Constraint = node.(*Constraint) + nod.Constraint = node.(*Constraint) } - if as.Column != nil { - node, ok := as.Column.Accept(v) + if nod.Column != nil { + node, ok := nod.Column.Accept(v) if !ok { - return as, false + return nod, false } - as.Column = node.(*ColumnDef) + nod.Column = node.(*ColumnDef) } - if as.ColumnName != nil { - node, ok := as.ColumnName.Accept(v) + if nod.ColumnName != nil { + node, ok := nod.ColumnName.Accept(v) if !ok { - return as, false + return nod, false } - as.ColumnName = node.(*ColumnName) + nod.ColumnName = node.(*ColumnName) } - if as.Position != nil { - node, ok := as.Position.Accept(v) + if nod.Position != nil { + node, ok := nod.Position.Accept(v) if !ok { - return as, false + return nod, false } - as.Position = node.(*ColumnPosition) + nod.Position = node.(*ColumnPosition) } - return v.Leave(as) + return v.Leave(nod) } // AlterTableStmt is a statement to change the structure of a table. @@ -524,23 +550,25 @@ type AlterTableStmt struct { } // Accept implements Node Accept interface. -func (at *AlterTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(at); skipChildren { - return at, ok +func (nod *AlterTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := at.Table.Accept(v) + nod = newNod.(*AlterTableStmt) + node, ok := nod.Table.Accept(v) if !ok { - return at, false + return nod, false } - at.Table = node.(*TableName) - for i, val := range at.Specs { + nod.Table = node.(*TableName) + for i, val := range nod.Specs { node, ok = val.Accept(v) if !ok { - return at, false + return nod, false } - at.Specs[i] = node.(*AlterTableSpec) + nod.Specs[i] = node.(*AlterTableSpec) } - return v.Leave(at) + return v.Leave(nod) } // TruncateTableStmt is a statement to empty a table completely. @@ -552,14 +580,16 @@ type TruncateTableStmt struct { } // Accept implements Node Accept interface. -func (ts *TruncateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (nod *TruncateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Table.Accept(v) + nod = newNod.(*TruncateTableStmt) + node, ok := nod.Table.Accept(v) if !ok { - return ts, false + return nod, false } - ts.Table = node.(*TableName) - return v.Leave(ts) + nod.Table = node.(*TableName) + return v.Leave(nod) } diff --git a/ast/dml.go b/ast/dml.go index b8ada81d99..f7f11421bd 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -55,34 +55,36 @@ type Join struct { // Tp represents join type. Tp JoinType // On represents join on condition. - On ExprNode + On *OnCondition } // Accept implements Node Accept interface. -func (j *Join) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(j); skipChildren { - return j, ok +func (nod *Join) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := j.Left.Accept(v) + nod = newNod.(*Join) + node, ok := nod.Left.Accept(v) if !ok { - return j, false + return nod, false } - j.Left = node.(ResultSetNode) - if j.Right != nil { - node, ok = j.Right.Accept(v) + nod.Left = node.(ResultSetNode) + if nod.Right != nil { + node, ok = nod.Right.Accept(v) if !ok { - return j, false + return nod, false } - j.Right = node.(ResultSetNode) + nod.Right = node.(ResultSetNode) } - if j.On != nil { - node, ok = j.On.Accept(v) + if nod.On != nil { + node, ok = nod.On.Accept(v) if !ok { - return j, false + return nod, false } - j.On = node.(ExprNode) + nod.On = node.(*OnCondition) } - return v.Leave(j) + return v.Leave(nod) } // TableName represents a table name. @@ -98,11 +100,13 @@ type TableName struct { } // Accept implements Node Accept interface. -func (tr *TableName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tr); skipChildren { - return tr, ok +func (nod *TableName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(tr) + nod = newNod.(*TableName) + return v.Leave(nod) } // TableSource represents table source with a name. @@ -118,47 +122,50 @@ type TableSource struct { } // Accept implements Node Accept interface. -func (ts *TableSource) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (nod *TableSource) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Source.Accept(v) + nod = newNod.(*TableSource) + node, ok := nod.Source.Accept(v) if !ok { - return ts, false + return nod, false } - ts.Source = node.(ResultSetNode) - return v.Leave(ts) + nod.Source = node.(ResultSetNode) + return v.Leave(nod) } -// SetResultFields implements ResultSet interface. -func (ts *TableSource) SetResultFields(rfs []*ResultField) { - ts.Source.SetResultFields(rfs) -} - -// GetResultFields implements ResultSet interface. -func (ts *TableSource) GetResultFields() []*ResultField { - return ts.Source.GetResultFields() -} - -// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. -type UnionClause struct { +// OnCondition represetns JOIN on condition. +type OnCondition struct { node - Distinct bool - Select *SelectStmt + Expr ExprNode } // Accept implements Node Accept interface. -func (uc *UnionClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(uc); skipChildren { - return uc, ok +func (nod *OnCondition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := uc.Select.Accept(v) + nod = newNod.(*OnCondition) + node, ok := nod.Expr.Accept(v) if !ok { - return uc, false + return nod, false } - uc.Select = node.(*SelectStmt) - return v.Leave(uc) + nod.Expr = node.(ExprNode) + return v.Leave(nod) +} + +// SetResultFields implements ResultSet interface. +func (nod *TableSource) SetResultFields(rfs []*ResultField) { + nod.Source.SetResultFields(rfs) +} + +// GetResultFields implements ResultSet interface. +func (nod *TableSource) GetResultFields() []*ResultField { + return nod.Source.GetResultFields() } // SelectLockType is the lock type for SelectStmt. @@ -175,22 +182,18 @@ const ( type WildCardField struct { node - Table *TableName + Table model.CIStr + Schema model.CIStr } // Accept implements Node Accept interface. -func (wf *WildCardField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(wf); skipChildren { - return wf, ok +func (nod *WildCardField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if wf.Table != nil { - node, ok := wf.Table.Accept(v) - if !ok { - return wf, false - } - wf.Table = node.(*TableName) - } - return v.Leave(wf) + nod = newNod.(*WildCardField) + return v.Leave(nod) } // SelectField represents fields in select statement. @@ -208,22 +211,70 @@ type SelectField struct { } // Accept implements Node Accept interface. -func (sf *SelectField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (nod *SelectField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if sf.Expr != nil { - node, ok := sf.Expr.Accept(v) + nod = newNod.(*SelectField) + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) if !ok { - return sf, false + return nod, false } - sf.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) } - return v.Leave(sf) + return v.Leave(nod) } -// OrderByItem represents a single order by item. -type OrderByItem struct { +// FieldList represents field list in select statement. +type FieldList struct { + node + + Fields []*SelectField +} + +// Accept implements Node Accept interface. +func (nod *FieldList) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*FieldList) + for i, val := range nod.Fields { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Fields[i] = node.(*SelectField) + } + return v.Leave(nod) +} + +// TableRefsClause represents table references clause in dml statement. +type TableRefsClause struct { + node + + TableRefs *Join +} + +// Accept implements Node Accept interface. +func (nod *TableRefsClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*TableRefsClause) + node, ok := nod.TableRefs.Accept(v) + if !ok { + return nod, false + } + nod.TableRefs = node.(*Join) + return v.Leave(nod) +} + +// ByItem represents an item in order by or group by. +type ByItem struct { node Expr ExprNode @@ -231,131 +282,236 @@ type OrderByItem struct { } // Accept implements Node Accept interface. -func (ob *OrderByItem) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ob); skipChildren { - return ob, ok +func (nod *ByItem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ob.Expr.Accept(v) + nod = newNod.(*ByItem) + node, ok := nod.Expr.Accept(v) if !ok { - return ob, false + return nod, false } - ob.Expr = node.(ExprNode) - return v.Leave(ob) + nod.Expr = node.(ExprNode) + return v.Leave(nod) +} + +// GroupByClause represents group by clause. +type GroupByClause struct { + node + Items []*ByItem +} + +// Accept implements Node Accept interface. +func (nod *GroupByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*GroupByClause) + for i, val := range nod.Items { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Items[i] = node.(*ByItem) + } + return v.Leave(nod) +} + +// HavingClause represents having clause. +type HavingClause struct { + node + Expr ExprNode +} + +// Accept implements Node Accept interface. +func (nod *HavingClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*HavingClause) + node, ok := nod.Expr.Accept(v) + if !ok { + return nod, false + } + nod.Expr = node.(ExprNode) + return v.Leave(nod) +} + +// OrderByClause represents order by clause. +type OrderByClause struct { + node + Items []*ByItem + ForUnion bool +} + +// Accept implements Node Accept interface. +func (nod *OrderByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*OrderByClause) + for i, val := range nod.Items { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Items[i] = node.(*ByItem) + } + return v.Leave(nod) } // SelectStmt represents the select query node. +// See: https://dev.mysql.com/doc/refman/5.7/en/select.html type SelectStmt struct { dmlNode resultSetNode // Distinct represents if the select has distinct option. Distinct bool - // Fields is the select expression list. - Fields []*SelectField // From is the from clause of the query. - From *Join + From *TableRefsClause // Where is the where clause in select statement. Where ExprNode + // Fields is the select expression list. + Fields *FieldList // GroupBy is the group by expression list. - GroupBy []ExprNode + GroupBy *GroupByClause // Having is the having condition. - Having ExprNode - // OrderBy is the odering expression list. - OrderBy []*OrderByItem + Having *HavingClause + // OrderBy is the ordering expression list. + OrderBy *OrderByClause // Limit is the limit clause. Limit *Limit // Lock is the lock type LockTp SelectLockType - - // Union clauses. - Unions []*UnionClause - // Order by for union select. - UnionOrderBy []*OrderByItem - // Limit for union select. - UnionLimit *Limit } // Accept implements Node Accept interface. -func (sn *SelectStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sn); skipChildren { - return sn, ok +func (nod *SelectStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range sn.Fields { - node, ok := val.Accept(v) + nod = newNod.(*SelectStmt) + + if nod.From != nil { + node, ok := nod.From.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Fields[i] = node.(*SelectField) - } - if sn.From != nil { - node, ok := sn.From.Accept(v) - if !ok { - return sn, false - } - sn.From = node.(*Join) + nod.From = node.(*TableRefsClause) } - if sn.Where != nil { - node, ok := sn.Where.Accept(v) + if nod.Where != nil { + node, ok := nod.Where.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - for i, val := range sn.GroupBy { - node, ok := val.Accept(v) + if nod.Fields != nil { + node, ok := nod.Fields.Accept(v) if !ok { - return sn, false + return nod, false } - sn.GroupBy[i] = node.(ExprNode) - } - if sn.Having != nil { - node, ok := sn.Having.Accept(v) - if !ok { - return sn, false - } - sn.Having = node.(ExprNode) + nod.Fields = node.(*FieldList) } - for i, val := range sn.OrderBy { - node, ok := val.Accept(v) + if nod.GroupBy != nil { + node, ok := nod.GroupBy.Accept(v) if !ok { - return sn, false + return nod, false } - sn.OrderBy[i] = node.(*OrderByItem) + nod.GroupBy = node.(*GroupByClause) } - if sn.Limit != nil { - node, ok := sn.Limit.Accept(v) + if nod.Having != nil { + node, ok := nod.Having.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Limit = node.(*Limit) + nod.Having = node.(*HavingClause) } - for i, val := range sn.Unions { + if nod.OrderBy != nil { + node, ok := nod.OrderBy.Accept(v) + if !ok { + return nod, false + } + nod.OrderBy = node.(*OrderByClause) + } + + if nod.Limit != nil { + node, ok := nod.Limit.Accept(v) + if !ok { + return nod, false + } + nod.Limit = node.(*Limit) + } + return v.Leave(nod) +} + +// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. +type UnionClause struct { + node + + Distinct bool + Select *SelectStmt +} + +// Accept implements Node Accept interface. +func (nod *UnionClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*UnionClause) + node, ok := nod.Select.Accept(v) + if !ok { + return nod, false + } + nod.Select = node.(*SelectStmt) + return v.Leave(nod) +} + +// UnionStmt represents "union statement" +// See: https://dev.mysql.com/doc/refman/5.7/en/union.html +type UnionStmt struct { + dmlNode + resultSetNode + + Distinct bool + Selects []*SelectStmt + Limit *Limit +} + +// Accept implements Node Accept interface. +func (nod *UnionStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*UnionStmt) + for i, val := range nod.Selects { node, ok := val.Accept(v) if !ok { - return sn, false + return nod, false } - sn.Unions[i] = node.(*UnionClause) + nod.Selects[i] = node.(*SelectStmt) } - for i, val := range sn.UnionOrderBy { - node, ok := val.Accept(v) + if nod.Limit != nil { + node, ok := nod.Limit.Accept(v) if !ok { - return sn, false + return nod, false } - sn.UnionOrderBy[i] = node.(*OrderByItem) + nod.Limit = node.(*Limit) } - if sn.UnionLimit != nil { - node, ok := sn.UnionLimit.Accept(v) - if !ok { - return sn, false - } - sn.UnionLimit = node.(*Limit) - } - return v.Leave(sn) + return v.Leave(nod) } // Assignment is the expression for assignment, like a = 1. @@ -368,21 +524,23 @@ type Assignment struct { } // Accept implements Node Accept interface. -func (as *Assignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (nod *Assignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := as.Column.Accept(v) + nod = newNod.(*Assignment) + node, ok := nod.Column.Accept(v) if !ok { - return as, false + return nod, false } - as.Column = node.(*ColumnName) - node, ok = as.Expr.Accept(v) + nod.Column = node.(*ColumnName) + node, ok = nod.Expr.Accept(v) if !ok { - return as, false + return nod, false } - as.Expr = node.(ExprNode) - return v.Leave(as) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // Priority const values. @@ -399,9 +557,9 @@ const ( type InsertStmt struct { dmlNode + Table *TableRefsClause Columns []*ColumnName Lists [][]ExprNode - Table *TableName Setlist []*Assignment Priority int OnDuplicate []*Assignment @@ -409,48 +567,57 @@ type InsertStmt struct { } // Accept implements Node Accept interface. -func (in *InsertStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(in); skipChildren { - return in, ok +func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range in.Columns { + nod = newNod.(*InsertStmt) + + node, ok := nod.Table.Accept(v) + if !ok { + return nod, false + } + nod.Table = node.(*TableRefsClause) + + for i, val := range nod.Columns { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.Columns[i] = node.(*ColumnName) + nod.Columns[i] = node.(*ColumnName) } - for i, list := range in.Lists { + for i, list := range nod.Lists { for j, val := range list { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.Lists[i][j] = node.(ExprNode) + nod.Lists[i][j] = node.(ExprNode) } } - for i, val := range in.Setlist { + for i, val := range nod.Setlist { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.Setlist[i] = node.(*Assignment) + nod.Setlist[i] = node.(*Assignment) } - for i, val := range in.OnDuplicate { + for i, val := range nod.OnDuplicate { node, ok := val.Accept(v) if !ok { - return in, false + return nod, false } - in.OnDuplicate[i] = node.(*Assignment) + nod.OnDuplicate[i] = node.(*Assignment) } - if in.Select != nil { - node, ok := in.Select.Accept(v) + if nod.Select != nil { + node, ok := nod.Select.Accept(v) if !ok { - return in, false + return nod, false } - in.Select = node.(*SelectStmt) + nod.Select = node.(*SelectStmt) } - return v.Leave(in) + return v.Leave(nod) } // DeleteStmt is a statement to delete rows from table. @@ -458,10 +625,12 @@ func (in *InsertStmt) Accept(v Visitor) (Node, bool) { type DeleteStmt struct { dmlNode - TableRefs *Join + // Used in both single table and multiple table delete statement. + TableRefs *TableRefsClause + // Only used in multiple table delete statement. Tables []*TableName Where ExprNode - Order []*OrderByItem + Order []*ByItem Limit *Limit LowPriority bool Ignore bool @@ -471,47 +640,49 @@ type DeleteStmt struct { } // Accept implements Node Accept interface. -func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(de); skipChildren { - return de, ok +func (nod *DeleteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } + nod = newNod.(*DeleteStmt) - node, ok := de.TableRefs.Accept(v) + node, ok := nod.TableRefs.Accept(v) if !ok { - return de, false + return nod, false } - de.TableRefs = node.(*Join) + nod.TableRefs = node.(*TableRefsClause) - for i, val := range de.Tables { + for i, val := range nod.Tables { node, ok = val.Accept(v) if !ok { - return de, false + return nod, false } - de.Tables[i] = node.(*TableName) + nod.Tables[i] = node.(*TableName) } - if de.Where != nil { - node, ok = de.Where.Accept(v) + if nod.Where != nil { + node, ok = nod.Where.Accept(v) if !ok { - return de, false + return nod, false } - de.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - for i, val := range de.Order { + for i, val := range nod.Order { node, ok = val.Accept(v) if !ok { - return de, false + return nod, false } - de.Order[i] = node.(*OrderByItem) + nod.Order[i] = node.(*ByItem) } - node, ok = de.Limit.Accept(v) + node, ok = nod.Limit.Accept(v) if !ok { - return de, false + return nod, false } - de.Limit = node.(*Limit) - return v.Leave(de) + nod.Limit = node.(*Limit) + return v.Leave(nod) } // UpdateStmt is a statement to update columns of existing rows in tables with new values. @@ -519,10 +690,10 @@ func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { type UpdateStmt struct { dmlNode - TableRefs *Join + TableRefs *TableRefsClause List []*Assignment Where ExprNode - Order []*OrderByItem + Order []*ByItem Limit *Limit LowPriority bool Ignore bool @@ -530,43 +701,45 @@ type UpdateStmt struct { } // Accept implements Node Accept interface. -func (up *UpdateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(up); skipChildren { - return up, ok +func (nod *UpdateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := up.TableRefs.Accept(v) + nod = newNod.(*UpdateStmt) + node, ok := nod.TableRefs.Accept(v) if !ok { - return up, false + return nod, false } - up.TableRefs = node.(*Join) - for i, val := range up.List { + nod.TableRefs = node.(*TableRefsClause) + for i, val := range nod.List { node, ok = val.Accept(v) if !ok { - return up, false + return nod, false } - up.List[i] = node.(*Assignment) + nod.List[i] = node.(*Assignment) } - if up.Where != nil { - node, ok = up.Where.Accept(v) + if nod.Where != nil { + node, ok = nod.Where.Accept(v) if !ok { - return up, false + return nod, false } - up.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - for i, val := range up.Order { + for i, val := range nod.Order { node, ok = val.Accept(v) if !ok { - return up, false + return nod, false } - up.Order[i] = node.(*OrderByItem) + nod.Order[i] = node.(*ByItem) } - node, ok = up.Limit.Accept(v) + node, ok = nod.Limit.Accept(v) if !ok { - return up, false + return nod, false } - up.Limit = node.(*Limit) - return v.Leave(up) + nod.Limit = node.(*Limit) + return v.Leave(nod) } // Limit is the limit clause. @@ -578,9 +751,11 @@ type Limit struct { } // Accept implements Node Accept interface. -func (l *Limit) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(l); skipChildren { - return l, ok +func (nod *Limit) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(l) + nod = newNod.(*Limit) + return v.Leave(nod) } diff --git a/ast/expressions.go b/ast/expressions.go index 558f1a019c..c99611ae13 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -52,16 +52,18 @@ type ValueExpr struct { } // IsStatic implements ExprNode interface. -func (val *ValueExpr) IsStatic() bool { +func (nod *ValueExpr) IsStatic() bool { return true } // Accept implements Node interface. -func (val *ValueExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(val); skipChildren { - return val, ok +func (nod *ValueExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(val) + nod = newNod.(*ValueExpr) + return v.Leave(nod) } // BetweenExpr is for "between and" or "not between and" expression. @@ -78,35 +80,37 @@ type BetweenExpr struct { } // Accept implements Node interface. -func (b *BetweenExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(b); skipChildren { - return b, ok +func (nod *BetweenExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } + nod = newNod.(*BetweenExpr) - node, ok := b.Expr.Accept(v) + node, ok := nod.Expr.Accept(v) if !ok { - return b, false + return nod, false } - b.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) - node, ok = b.Left.Accept(v) + node, ok = nod.Left.Accept(v) if !ok { - return b, false + return nod, false } - b.Left = node.(ExprNode) + nod.Left = node.(ExprNode) - node, ok = b.Right.Accept(v) + node, ok = nod.Right.Accept(v) if !ok { - return b, false + return nod, false } - b.Right = node.(ExprNode) + nod.Right = node.(ExprNode) - return v.Leave(b) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (b *BetweenExpr) IsStatic() bool { - return b.Expr.IsStatic() && b.Left.IsStatic() && b.Right.IsStatic() +func (nod *BetweenExpr) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Left.IsStatic() && nod.Right.IsStatic() } // BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc. @@ -121,29 +125,31 @@ type BinaryOperationExpr struct { } // Accept implements Node interface. -func (o *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(o); skipChildren { - return o, ok +func (nod *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } + nod = newNod.(*BinaryOperationExpr) - node, ok := o.L.Accept(v) + node, ok := nod.L.Accept(v) if !ok { - return o, false + return nod, false } - o.L = node.(ExprNode) + nod.L = node.(ExprNode) - node, ok = o.R.Accept(v) + node, ok = nod.R.Accept(v) if !ok { - return o, false + return nod, false } - o.R = node.(ExprNode) + nod.R = node.(ExprNode) - return v.Leave(o) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (o *BinaryOperationExpr) IsStatic() bool { - return o.L.IsStatic() && o.R.IsStatic() +func (nod *BinaryOperationExpr) IsStatic() bool { + return nod.L.IsStatic() && nod.R.IsStatic() } // WhenClause is the when clause in Case expression for "when condition then result". @@ -156,27 +162,29 @@ type WhenClause struct { } // Accept implements Node Accept interface. -func (w *WhenClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(w); skipChildren { - return w, ok +func (nod *WhenClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := w.Expr.Accept(v) + nod = newNod.(*WhenClause) + node, ok := nod.Expr.Accept(v) if !ok { - return w, false + return nod, false } - w.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) - node, ok = w.Result.Accept(v) + node, ok = nod.Result.Accept(v) if !ok { - return w, false + return nod, false } - w.Result = node.(ExprNode) - return v.Leave(w) + nod.Result = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (w *WhenClause) IsStatic() bool { - return w.Expr.IsStatic() && w.Result.IsStatic() +func (nod *WhenClause) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Result.IsStatic() } // CaseExpr is the case expression. @@ -191,45 +199,47 @@ type CaseExpr struct { } // Accept implements Node Accept interface. -func (f *CaseExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (nod *CaseExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if f.Value != nil { - node, ok := f.Value.Accept(v) + nod = newNod.(*CaseExpr) + if nod.Value != nil { + node, ok := nod.Value.Accept(v) if !ok { - return f, false + return nod, false } - f.Value = node.(ExprNode) + nod.Value = node.(ExprNode) } - for i, val := range f.WhenClauses { + for i, val := range nod.WhenClauses { node, ok := val.Accept(v) if !ok { - return f, false + return nod, false } - f.WhenClauses[i] = node.(*WhenClause) + nod.WhenClauses[i] = node.(*WhenClause) } - if f.ElseClause != nil { - node, ok := f.ElseClause.Accept(v) + if nod.ElseClause != nil { + node, ok := nod.ElseClause.Accept(v) if !ok { - return f, false + return nod, false } - f.ElseClause = node.(ExprNode) + nod.ElseClause = node.(ExprNode) } - return v.Leave(f) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (f *CaseExpr) IsStatic() bool { - if f.Value != nil && !f.Value.IsStatic() { +func (nod *CaseExpr) IsStatic() bool { + if nod.Value != nil && !nod.Value.IsStatic() { return false } - for _, w := range f.WhenClauses { + for _, w := range nod.WhenClauses { if !w.IsStatic() { return false } } - if f.ElseClause != nil && !f.ElseClause.IsStatic() { + if nod.ElseClause != nil && !nod.ElseClause.IsStatic() { return false } return true @@ -243,26 +253,28 @@ type SubqueryExpr struct { } // Accept implements Node Accept interface. -func (sq *SubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sq); skipChildren { - return sq, ok +func (nod *SubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := sq.Query.Accept(v) + nod = newNod.(*SubqueryExpr) + node, ok := nod.Query.Accept(v) if !ok { - return sq, false + return nod, false } - sq.Query = node.(*SelectStmt) - return v.Leave(sq) + nod.Query = node.(*SelectStmt) + return v.Leave(nod) } // SetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) SetResultFields(rfs []*ResultField) { - sq.Query.SetResultFields(rfs) +func (nod *SubqueryExpr) SetResultFields(rfs []*ResultField) { + nod.Query.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) GetResultFields() []*ResultField { - return sq.Query.GetResultFields() +func (nod *SubqueryExpr) GetResultFields() []*ResultField { + return nod.Query.GetResultFields() } // CompareSubqueryExpr is the expression for "expr cmp (select ...)". @@ -282,21 +294,23 @@ type CompareSubqueryExpr struct { } // Accept implements Node Accept interface. -func (cs *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (nod *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := cs.L.Accept(v) + nod = newNod.(*CompareSubqueryExpr) + node, ok := nod.L.Accept(v) if !ok { - return cs, false + return nod, false } - cs.L = node.(ExprNode) - node, ok = cs.R.Accept(v) + nod.L = node.(ExprNode) + node, ok = nod.R.Accept(v) if !ok { - return cs, false + return nod, false } - cs.R = node.(*SubqueryExpr) - return v.Leave(cs) + nod.R = node.(*SubqueryExpr) + return v.Leave(nod) } // ColumnName represents column name. @@ -312,11 +326,13 @@ type ColumnName struct { } // Accept implements Node Accept interface. -func (cn *ColumnName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cn); skipChildren { - return cn, ok +func (nod *ColumnName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cn) + nod = newNod.(*ColumnName) + return v.Leave(nod) } // ColumnNameExpr represents a column name expression. @@ -328,16 +344,18 @@ type ColumnNameExpr struct { } // Accept implements Node Accept interface. -func (cr *ColumnNameExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cr); skipChildren { - return cr, ok +func (nod *ColumnNameExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := cr.Name.Accept(v) + nod = newNod.(*ColumnNameExpr) + node, ok := nod.Name.Accept(v) if !ok { - return cr, false + return nod, false } - cr.Name = node.(*ColumnName) - return v.Leave(cr) + nod.Name = node.(*ColumnName) + return v.Leave(nod) } // DefaultExpr is the default expression using default value for a column. @@ -348,18 +366,20 @@ type DefaultExpr struct { } // Accept implements Node Accept interface. -func (d *DefaultExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(d); skipChildren { - return d, ok +func (nod *DefaultExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if d.Name != nil { - node, ok := d.Name.Accept(v) + nod = newNod.(*DefaultExpr) + if nod.Name != nil { + node, ok := nod.Name.Accept(v) if !ok { - return d, false + return nod, false } - d.Name = node.(*ColumnName) + nod.Name = node.(*ColumnName) } - return v.Leave(d) + return v.Leave(nod) } // IdentifierExpr represents an identifier expression. @@ -370,11 +390,13 @@ type IdentifierExpr struct { } // Accept implements Node Accept interface. -func (i *IdentifierExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(i); skipChildren { - return i, ok +func (nod *IdentifierExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(i) + nod = newNod.(*IdentifierExpr) + return v.Leave(nod) } // ExistsSubqueryExpr is the expression for "exists (select ...)". @@ -386,16 +408,18 @@ type ExistsSubqueryExpr struct { } // Accept implements Node Accept interface. -func (es *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (nod *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Sel.Accept(v) + nod = newNod.(*ExistsSubqueryExpr) + node, ok := nod.Sel.Accept(v) if !ok { - return es, false + return nod, false } - es.Sel = node.(*SubqueryExpr) - return v.Leave(es) + nod.Sel = node.(*SubqueryExpr) + return v.Leave(nod) } // PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)". @@ -412,30 +436,32 @@ type PatternInExpr struct { } // Accept implements Node Accept interface. -func (pi *PatternInExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pi); skipChildren { - return pi, ok +func (nod *PatternInExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := pi.Expr.Accept(v) + nod = newNod.(*PatternInExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return pi, false + return nod, false } - pi.Expr = node.(ExprNode) - for i, val := range pi.List { + nod.Expr = node.(ExprNode) + for i, val := range nod.List { node, ok = val.Accept(v) if !ok { - return pi, false + return nod, false } - pi.List[i] = node.(ExprNode) + nod.List[i] = node.(ExprNode) } - if pi.Sel != nil { - node, ok = pi.Sel.Accept(v) + if nod.Sel != nil { + node, ok = nod.Sel.Accept(v) if !ok { - return pi, false + return nod, false } - pi.Sel = node.(*SubqueryExpr) + nod.Sel = node.(*SubqueryExpr) } - return v.Leave(pi) + return v.Leave(nod) } // IsNullExpr is the expression for null check. @@ -448,21 +474,23 @@ type IsNullExpr struct { } // Accept implements Node Accept interface. -func (is *IsNullExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (nod *IsNullExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + nod = newNod.(*IsNullExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return is, false + return nod, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsNullExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (nod *IsNullExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // IsTruthExpr is the expression for true/false check. @@ -477,21 +505,23 @@ type IsTruthExpr struct { } // Accept implements Node Accept interface. -func (is *IsTruthExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (nod *IsTruthExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + nod = newNod.(*IsTruthExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return is, false + return nod, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsTruthExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (nod *IsTruthExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // PatternLikeExpr is the expression for like operator, e.g, expr like "%123%" @@ -508,26 +538,28 @@ type PatternLikeExpr struct { } // Accept implements Node Accept interface. -func (pl *PatternLikeExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pl); skipChildren { - return pl, ok +func (nod *PatternLikeExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := pl.Expr.Accept(v) + nod = newNod.(*PatternLikeExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return pl, false + return nod, false } - pl.Expr = node.(ExprNode) - node, ok = pl.Pattern.Accept(v) + nod.Expr = node.(ExprNode) + node, ok = nod.Pattern.Accept(v) if !ok { - return pl, false + return nod, false } - pl.Pattern = node.(ExprNode) - return v.Leave(pl) + nod.Pattern = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (pl *PatternLikeExpr) IsStatic() bool { - return pl.Expr.IsStatic() && pl.Pattern.IsStatic() +func (nod *PatternLikeExpr) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Pattern.IsStatic() } // ParamMarkerExpr expresion holds a place for another expression. @@ -537,11 +569,13 @@ type ParamMarkerExpr struct { } // Accept implements Node Accept interface. -func (pm *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pm); skipChildren { - return pm, ok +func (nod *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(pm) + nod = newNod.(*ParamMarkerExpr) + return v.Leave(nod) } // ParenthesesExpr is the parentheses expression. @@ -552,23 +586,25 @@ type ParenthesesExpr struct { } // Accept implements Node Accept interface. -func (p *ParenthesesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (nod *ParenthesesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if p.Expr != nil { - node, ok := p.Expr.Accept(v) + nod = newNod.(*ParenthesesExpr) + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) if !ok { - return p, false + return nod, false } - p.Expr = node.(ExprNode) + nod.Expr = node.(ExprNode) } - return v.Leave(p) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (p *ParenthesesExpr) IsStatic() bool { - return p.Expr.IsStatic() +func (nod *ParenthesesExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // PositionExpr is the expression for order by and group by position. @@ -583,16 +619,18 @@ type PositionExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (p *PositionExpr) IsStatic() bool { +func (nod *PositionExpr) IsStatic() bool { return true } // Accept implements Node Accept interface. -func (p *PositionExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (nod *PositionExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(p) + nod = newNod.(*PositionExpr) + return v.Leave(nod) } // PatternRegexpExpr is the pattern expression for pattern match. @@ -607,26 +645,28 @@ type PatternRegexpExpr struct { } // Accept implements Node Accept interface. -func (p *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (nod *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := p.Expr.Accept(v) + nod = newNod.(*PatternRegexpExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return p, false + return nod, false } - p.Expr = node.(ExprNode) - node, ok = p.Pattern.Accept(v) + nod.Expr = node.(ExprNode) + node, ok = nod.Pattern.Accept(v) if !ok { - return p, false + return nod, false } - p.Pattern = node.(ExprNode) - return v.Leave(p) + nod.Pattern = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (p *PatternRegexpExpr) IsStatic() bool { - return p.Expr.IsStatic() && p.Pattern.IsStatic() +func (nod *PatternRegexpExpr) IsStatic() bool { + return nod.Expr.IsStatic() && nod.Pattern.IsStatic() } // RowExpr is the expression for row constructor. @@ -638,23 +678,25 @@ type RowExpr struct { } // Accept implements Node Accept interface. -func (r *RowExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(r); skipChildren { - return r, ok +func (nod *RowExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range r.Values { + nod = newNod.(*RowExpr) + for i, val := range nod.Values { node, ok := val.Accept(v) if !ok { - return r, false + return nod, false } - r.Values[i] = node.(ExprNode) + nod.Values[i] = node.(ExprNode) } - return v.Leave(r) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (r *RowExpr) IsStatic() bool { - for _, v := range r.Values { +func (nod *RowExpr) IsStatic() bool { + for _, v := range nod.Values { if !v.IsStatic() { return false } @@ -672,21 +714,23 @@ type UnaryOperationExpr struct { } // Accept implements Node Accept interface. -func (u *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(u); skipChildren { - return u, ok +func (nod *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := u.V.Accept(v) + nod = newNod.(*UnaryOperationExpr) + node, ok := nod.V.Accept(v) if !ok { - return u, false + return nod, false } - u.V = node.(ExprNode) - return v.Leave(u) + nod.V = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (u *UnaryOperationExpr) IsStatic() bool { - return u.V.IsStatic() +func (nod *UnaryOperationExpr) IsStatic() bool { + return nod.V.IsStatic() } // ValuesExpr is the expression used in INSERT VALUES @@ -697,16 +741,18 @@ type ValuesExpr struct { } // Accept implements Node Accept interface. -func (va *ValuesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (nod *ValuesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Column.Accept(v) + nod = newNod.(*ValuesExpr) + node, ok := nod.Column.Accept(v) if !ok { - return va, false + return nod, false } - va.Column = node.(*ColumnName) - return v.Leave(va) + nod.Column = node.(*ColumnName) + return v.Leave(nod) } // VariableExpr is the expression for variable. @@ -721,9 +767,11 @@ type VariableExpr struct { } // Accept implements Node Accept interface. -func (va *VariableExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (nod *VariableExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(va) + nod = newNod.(*VariableExpr) + return v.Leave(nod) } diff --git a/ast/functions.go b/ast/functions.go index 3843478313..4e052b3ccf 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -42,28 +42,30 @@ type FuncCallExpr struct { } // Accept implements Node interface. -func (c *FuncCallExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(c); skipChildren { - return c, ok +func (nod *FuncCallExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range c.Args { + nod = newNod.(*FuncCallExpr) + for i, val := range nod.Args { node, ok := val.Accept(v) if !ok { - return c, false + return nod, false } - c.Args[i] = node.(ExprNode) + nod.Args[i] = node.(ExprNode) } - return v.Leave(c) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (c *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(c.F)] +func (nod *FuncCallExpr) IsStatic() bool { + v := builtin.Funcs[strings.ToLower(nod.F)] if v.F == nil || !v.IsStatic { return false } - for _, v := range c.Args { + for _, v := range nod.Args { if !v.IsStatic() { return false } @@ -81,21 +83,23 @@ type FuncExtractExpr struct { } // Accept implements Node Accept interface. -func (ex *FuncExtractExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ex); skipChildren { - return ex, ok +func (nod *FuncExtractExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ex.Date.Accept(v) + nod = newNod.(*FuncExtractExpr) + node, ok := nod.Date.Accept(v) if !ok { - return ex, false + return nod, false } - ex.Date = node.(ExprNode) - return v.Leave(ex) + nod.Date = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (ex *FuncExtractExpr) IsStatic() bool { - return ex.Date.IsStatic() +func (nod *FuncExtractExpr) IsStatic() bool { + return nod.Date.IsStatic() } // FuncConvertExpr provides a way to convert data between different character sets. @@ -109,21 +113,23 @@ type FuncConvertExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncConvertExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (nod *FuncConvertExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncConvertExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (nod *FuncConvertExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + nod = newNod.(*FuncConvertExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return f, false + return nod, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // CastFunctionType is the type for cast function. @@ -149,21 +155,23 @@ type FuncCastExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncCastExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (nod *FuncCastExpr) IsStatic() bool { + return nod.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncCastExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (nod *FuncCastExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + nod = newNod.(*FuncCastExpr) + node, ok := nod.Expr.Accept(v) if !ok { - return f, false + return nod, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + nod.Expr = node.(ExprNode) + return v.Leave(nod) } // FuncSubstringExpr returns the substring as specified. @@ -177,31 +185,33 @@ type FuncSubstringExpr struct { } // Accept implements Node Accept interface. -func (sf *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (nod *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := sf.StrExpr.Accept(v) + nod = newNod.(*FuncSubstringExpr) + node, ok := nod.StrExpr.Accept(v) if !ok { - return sf, false + return nod, false } - sf.StrExpr = node.(ExprNode) - node, ok = sf.Pos.Accept(v) + nod.StrExpr = node.(ExprNode) + node, ok = nod.Pos.Accept(v) if !ok { - return sf, false + return nod, false } - sf.Pos = node.(ExprNode) - node, ok = sf.Len.Accept(v) + nod.Pos = node.(ExprNode) + node, ok = nod.Len.Accept(v) if !ok { - return sf, false + return nod, false } - sf.Len = node.(ExprNode) - return v.Leave(sf) + nod.Len = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (sf *FuncSubstringExpr) IsStatic() bool { - return sf.StrExpr.IsStatic() && sf.Pos.IsStatic() && sf.Len.IsStatic() +func (nod *FuncSubstringExpr) IsStatic() bool { + return nod.StrExpr.IsStatic() && nod.Pos.IsStatic() && nod.Len.IsStatic() } // FuncSubstringIndexExpr returns the substring as specified. @@ -215,26 +225,28 @@ type FuncSubstringIndexExpr struct { } // Accept implements Node Accept interface. -func (si *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(si); skipChildren { - return si, ok +func (nod *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := si.StrExpr.Accept(v) + nod = newNod.(*FuncSubstringIndexExpr) + node, ok := nod.StrExpr.Accept(v) if !ok { - return si, false + return nod, false } - si.StrExpr = node.(ExprNode) - node, ok = si.Delim.Accept(v) + nod.StrExpr = node.(ExprNode) + node, ok = nod.Delim.Accept(v) if !ok { - return si, false + return nod, false } - si.Delim = node.(ExprNode) - node, ok = si.Count.Accept(v) + nod.Delim = node.(ExprNode) + node, ok = nod.Count.Accept(v) if !ok { - return si, false + return nod, false } - si.Count = node.(ExprNode) - return v.Leave(si) + nod.Count = node.(ExprNode) + return v.Leave(nod) } // FuncLocateExpr returns the position of the first occurrence of substring. @@ -248,26 +260,28 @@ type FuncLocateExpr struct { } // Accept implements Node Accept interface. -func (le *FuncLocateExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(le); skipChildren { - return le, ok +func (nod *FuncLocateExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := le.Str.Accept(v) + nod = newNod.(*FuncLocateExpr) + node, ok := nod.Str.Accept(v) if !ok { - return le, false + return nod, false } - le.Str = node.(ExprNode) - node, ok = le.SubStr.Accept(v) + nod.Str = node.(ExprNode) + node, ok = nod.SubStr.Accept(v) if !ok { - return le, false + return nod, false } - le.SubStr = node.(ExprNode) - node, ok = le.Pos.Accept(v) + nod.SubStr = node.(ExprNode) + node, ok = nod.Pos.Accept(v) if !ok { - return le, false + return nod, false } - le.Pos = node.(ExprNode) - return v.Leave(le) + nod.Pos = node.(ExprNode) + return v.Leave(nod) } // TrimDirectionType is the type for trim direction. @@ -295,26 +309,28 @@ type FuncTrimExpr struct { } // Accept implements Node Accept interface. -func (tf *FuncTrimExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tf); skipChildren { - return tf, ok +func (nod *FuncTrimExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := tf.Str.Accept(v) + nod = newNod.(*FuncTrimExpr) + node, ok := nod.Str.Accept(v) if !ok { - return tf, false + return nod, false } - tf.Str = node.(ExprNode) - node, ok = tf.RemStr.Accept(v) + nod.Str = node.(ExprNode) + node, ok = nod.RemStr.Accept(v) if !ok { - return tf, false + return nod, false } - tf.RemStr = node.(ExprNode) - return v.Leave(tf) + nod.RemStr = node.(ExprNode) + return v.Leave(nod) } // IsStatic implements the ExprNode IsStatic interface. -func (tf *FuncTrimExpr) IsStatic() bool { - return tf.Str.IsStatic() && tf.RemStr.IsStatic() +func (nod *FuncTrimExpr) IsStatic() bool { + return nod.Str.IsStatic() && nod.RemStr.IsStatic() } // TypeStar is a special type for "*". diff --git a/ast/misc.go b/ast/misc.go index 2ad94e7755..d269e27c84 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -57,16 +57,18 @@ type ExplainStmt struct { } // Accept implements Node Accept interface. -func (es *ExplainStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (nod *ExplainStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Stmt.Accept(v) + nod = newNod.(*ExplainStmt) + node, ok := nod.Stmt.Accept(v) if !ok { - return es, false + return nod, false } - es.Stmt = node.(DMLNode) - return v.Leave(es) + nod.Stmt = node.(DMLNode) + return v.Leave(nod) } // PrepareStmt is a statement to prepares a SQL statement which contains placeholders, @@ -83,16 +85,18 @@ type PrepareStmt struct { } // Accept implements Node Accept interface. -func (ps *PrepareStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ps); skipChildren { - return ps, ok +func (nod *PrepareStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := ps.SQLVar.Accept(v) + nod = newNod.(*PrepareStmt) + node, ok := nod.SQLVar.Accept(v) if !ok { - return ps, false + return nod, false } - ps.SQLVar = node.(*VariableExpr) - return v.Leave(ps) + nod.SQLVar = node.(*VariableExpr) + return v.Leave(nod) } // DeallocateStmt is a statement to release PreparedStmt. @@ -105,11 +109,13 @@ type DeallocateStmt struct { } // Accept implements Node Accept interface. -func (ds *DeallocateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ds); skipChildren { - return ds, ok +func (nod *DeallocateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(ds) + nod = newNod.(*DeallocateStmt) + return v.Leave(nod) } // ExecuteStmt is a statement to execute PreparedStmt. @@ -123,18 +129,20 @@ type ExecuteStmt struct { } // Accept implements Node Accept interface. -func (es *ExecuteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (nod *ExecuteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range es.UsingVars { + nod = newNod.(*ExecuteStmt) + for i, val := range nod.UsingVars { node, ok := val.Accept(v) if !ok { - return es, false + return nod, false } - es.UsingVars[i] = node.(ExprNode) + nod.UsingVars[i] = node.(ExprNode) } - return v.Leave(es) + return v.Leave(nod) } // ShowStmtType is the type for SHOW statement. @@ -173,39 +181,41 @@ type ShowStmt struct { } // Accept implements Node Accept interface. -func (ss *ShowStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ss); skipChildren { - return ss, ok +func (nod *ShowStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - if ss.Table != nil { - node, ok := ss.Table.Accept(v) + nod = newNod.(*ShowStmt) + if nod.Table != nil { + node, ok := nod.Table.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Table = node.(*TableName) + nod.Table = node.(*TableName) } - if ss.Column != nil { - node, ok := ss.Column.Accept(v) + if nod.Column != nil { + node, ok := nod.Column.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Column = node.(*ColumnName) + nod.Column = node.(*ColumnName) } - if ss.Pattern != nil { - node, ok := ss.Pattern.Accept(v) + if nod.Pattern != nil { + node, ok := nod.Pattern.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Pattern = node.(*PatternLikeExpr) + nod.Pattern = node.(*PatternLikeExpr) } - if ss.Where != nil { - node, ok := ss.Where.Accept(v) + if nod.Where != nil { + node, ok := nod.Where.Accept(v) if !ok { - return ss, false + return nod, false } - ss.Where = node.(ExprNode) + nod.Where = node.(ExprNode) } - return v.Leave(ss) + return v.Leave(nod) } // BeginStmt is a statement to start a new transaction. @@ -215,11 +225,13 @@ type BeginStmt struct { } // Accept implements Node Accept interface. -func (bs *BeginStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(bs); skipChildren { - return bs, ok +func (nod *BeginStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(bs) + nod = newNod.(*BeginStmt) + return v.Leave(nod) } // CommitStmt is a statement to commit the current transaction. @@ -229,11 +241,13 @@ type CommitStmt struct { } // Accept implements Node Accept interface. -func (cs *CommitStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (nod *CommitStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cs) + nod = newNod.(*CommitStmt) + return v.Leave(nod) } // RollbackStmt is a statement to roll back the current transaction. @@ -243,11 +257,13 @@ type RollbackStmt struct { } // Accept implements Node Accept interface. -func (rs *RollbackStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rs); skipChildren { - return rs, ok +func (nod *RollbackStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(rs) + nod = newNod.(*RollbackStmt) + return v.Leave(nod) } // UseStmt is a statement to use the DBName database as the current database. @@ -259,11 +275,13 @@ type UseStmt struct { } // Accept implements Node Accept interface. -func (us *UseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(us); skipChildren { - return us, ok +func (nod *UseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(us) + nod = newNod.(*UseStmt) + return v.Leave(nod) } // VariableAssignment is a variable assignment struct. @@ -276,16 +294,18 @@ type VariableAssignment struct { } // Accept implements Node interface. -func (va *VariableAssignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (nod *VariableAssignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Value.Accept(v) + nod = newNod.(*VariableAssignment) + node, ok := nod.Value.Accept(v) if !ok { - return va, false + return nod, false } - va.Value = node.(ExprNode) - return v.Leave(va) + nod.Value = node.(ExprNode) + return v.Leave(nod) } // SetStmt is the statement to set variables. @@ -296,18 +316,20 @@ type SetStmt struct { } // Accept implements Node Accept interface. -func (set *SetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (nod *SetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range set.Variables { + nod = newNod.(*SetStmt) + for i, val := range nod.Variables { node, ok := val.Accept(v) if !ok { - return set, false + return nod, false } - set.Variables[i] = node.(*VariableAssignment) + nod.Variables[i] = node.(*VariableAssignment) } - return v.Leave(set) + return v.Leave(nod) } // SetCharsetStmt is a statement to assign values to character and collation variables. @@ -320,11 +342,13 @@ type SetCharsetStmt struct { } // Accept implements Node Accept interface. -func (set *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (nod *SetCharsetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + nod = newNod.(*SetCharsetStmt) + return v.Leave(nod) } // SetPwdStmt is a statement to assign a password to user account. @@ -337,11 +361,13 @@ type SetPwdStmt struct { } // Accept implements Node Accept interface. -func (set *SetPwdStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (nod *SetPwdStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + nod = newNod.(*SetPwdStmt) + return v.Leave(nod) } // UserSpec is used for parsing create user statement. @@ -367,16 +393,18 @@ type DoStmt struct { } // Accept implements Node Accept interface. -func (do *DoStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(do); skipChildren { - return do, ok +func (nod *DoStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) } - for i, val := range do.Exprs { + nod = newNod.(*DoStmt) + for i, val := range nod.Exprs { node, ok := val.Accept(v) if !ok { - return do, false + return nod, false } - do.Exprs[i] = node.(ExprNode) + nod.Exprs[i] = node.(ExprNode) } - return v.Leave(do) + return v.Leave(nod) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 24cde087e6..97790fd0d5 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -394,7 +394,7 @@ import ( FieldAsName "Field alias name" FieldAsNameOpt "Field alias name opt" FieldList "field expression list" - FromClause "From clause" + TableRefsClause "Table references clause" Function "function expr" FunctionCallAgg "Function call on aggregate data" FunctionCallConflict "Function call with reserved keyword as function name" @@ -404,7 +404,6 @@ import ( FuncDatetimePrec "Function datetime precision" GlobalScope "The scope of variable" GroupByClause "GROUP BY clause" - GroupByList "GROUP BY list" HashString "Hashed string" HavingClause "HAVING clause" IfExists "If Exists" @@ -437,9 +436,9 @@ import ( OptInteger "Optional Integer keyword" Order "ORDER BY clause optional collation specification" OrderBy "ORDER BY clause" - OrderByItem "ORDER BY list item" + ByItem "BY item" OrderByOptional "Optional ORDER BY clause optional" - OrderByList "ORDER BY list" + ByList "BY list" OuterOpt "optional OUTER clause" QuickOptional "QUICK or empty" PasswordOpt "Password option" @@ -452,7 +451,6 @@ import ( ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" RollbackStmt "ROLLBACK statement" - SelectBasic "Basic SELECT statement without parentheses and union" SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," SelectStmt "SELECT statement" SelectStmtCalcFoundRows "SELECT statement optional SQL_CALC_FOUND_ROWS" @@ -491,9 +489,8 @@ import ( TrimDirection "Trim string direction" TruncateTableStmt "TRANSACTION TABLE statement" UnionOpt "Union Option(empty/ALL/DISTINCT)" - UnionClause "Union select" + UnionStmt "Union select state ment" UnionClauseList "Union select clause list" - UnionClauseP "Union (select)" UnionClausePList "Union (select) clause list" UpdateStmt "UPDATE statement" Username "Username" @@ -1141,12 +1138,13 @@ DeleteFromStmt: "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableName WhereClauseOptional OrderByOptional LimitClause { // Single Table + join := &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil} x := &ast.DeleteStmt{ - TableRefs: &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil}, + TableRefs: &ast.TableRefsClause{TableRefs: join}, LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - Order: $8.([]*ast.OrderByItem), + Order: $8.([]*ast.ByItem), } if $7 != nil { x.Where = $7.(ast.ExprNode) @@ -1171,7 +1169,7 @@ DeleteFromStmt: MultiTable: true, BeforeFrom: true, Tables: $5.([]*ast.TableName), - TableRefs: $7.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $7.(*ast.Join)}, } if $8 != nil { x.Where = $8.(ast.ExprNode) @@ -1190,7 +1188,7 @@ DeleteFromStmt: Ignore: $4.(bool), MultiTable: true, Tables: $6.([]*ast.TableName), - TableRefs: $8.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $8.(*ast.Join)}, } if $9 != nil { x.Where = $9.(ast.ExprNode) @@ -1493,13 +1491,13 @@ Field: } | Identifier '.' '*' { - tn := &ast.TableName{Name:model.NewCIStr($1.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} + wildCard := &ast.WildCardField{Table: model.NewCIStr($1.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Identifier '.' Identifier '.' '*' { - tn := &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} + wildCard := &ast.WildCardField{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Expression FieldAsNameOpt { @@ -1545,19 +1543,9 @@ FieldList: } GroupByClause: - "GROUP" "BY" GroupByList + "GROUP" "BY" ByList { - $$ = $3.([]ast.ExprNode) - } - -GroupByList: - Expression - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| GroupByList ',' Expression - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) + $$ = &ast.GroupByClause{Items: $3.([]*ast.ByItem)} } HavingClause: @@ -1566,7 +1554,7 @@ HavingClause: } | "HAVING" Expression { - $$ = $2.(ast.ExprNode) + $$ = &ast.HavingClause{Expr: $2.(ast.ExprNode)} } IfExists: @@ -1643,7 +1631,9 @@ InsertIntoStmt: { x := $6.(*ast.InsertStmt) x.Priority = $2.(int) - x.Table = $5.(*ast.TableName) + // Wraps many layers here so that it can be processed the same way as select statement. + ts := &ast.TableSource{Source: $5.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} if $7 != nil { x.OnDuplicate = $7.([]*ast.Assignment) } @@ -1804,25 +1794,25 @@ Operand: } OrderBy: - "ORDER" "BY" OrderByList + "ORDER" "BY" ByList { - $$ = $3.([]*ast.OrderByItem) + $$ = &ast.OrderByClause{Items: $3.([]*ast.ByItem)} } -OrderByList: - OrderByItem +ByList: + ByItem { - $$ = []*ast.OrderByItem{$1.(*ast.OrderByItem)} + $$ = []*ast.ByItem{$1.(*ast.ByItem)} } -| OrderByList ',' OrderByItem +| ByList ',' ByItem { - $$ = append($1.([]*ast.OrderByItem), $3.(*ast.OrderByItem)) + $$ = append($1.([]*ast.ByItem), $3.(*ast.ByItem)) } -OrderByItem: +ByItem: Expression Order { - $$ = &ast.OrderByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} + $$ = &ast.ByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} } Order: @@ -2556,40 +2546,23 @@ RollbackStmt: } SelectStmt: - SelectBasic -| SelectBasic UnionClauseList - { - st := $1.(*ast.SelectStmt) - st.Unions = $2.([]*ast.UnionClause) - $$ = st - } -| SubSelect UnionClausePList OrderByOptional SelectStmtLimit - { - st := $1.(*ast.SubqueryExpr).Query - st.Unions = $2.([]*ast.UnionClause) - st.UnionOrderBy = $3.([]*ast.OrderByItem) - st.UnionLimit = $4.(*ast.Limit) - $$ = st - } - -SelectBasic: "SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt { $$ = &ast.SelectStmt { Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), + Fields: $3.(*ast.FieldList), From: nil, LockTp: $6.(ast.SelectLockType), } } | "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - FromClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional + TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional SelectStmtLimit SelectLockOpt { st := &ast.SelectStmt{ Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), - From: $5.(*ast.Join), + Fields: $3.(*ast.FieldList), + From: $5.(*ast.TableRefsClause), LockTp: $11.(ast.SelectLockType), } @@ -2598,15 +2571,15 @@ SelectBasic: } if $7 != nil { - st.GroupBy = $7.([]ast.ExprNode) + st.GroupBy = $7.(*ast.GroupByClause) } if $8 != nil { - st.Having = $8.(ast.ExprNode) + st.Having = $8.(*ast.HavingClause) } if $9 != nil { - st.OrderBy = $9.([]*ast.OrderByItem) + st.OrderBy = $9.(*ast.OrderByClause) } if $10 != nil { @@ -2621,17 +2594,17 @@ FromDual: | "FROM" "DUAL" -FromClause: +TableRefsClause: TableRefs { - $$ = $1 + $$ = &ast.TableRefsClause{TableRefs: $1.(*ast.Join)} } TableRefs: EscapedTableRef { if j, ok := $1.(*ast.Join); ok { - // if $1 is JoinRset, use it directly + // if $1 is Join, use it directly $$ = j } else { $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: nil} @@ -2708,11 +2681,13 @@ JoinTable: } | TableRef CrossOpt TableRef "ON" Expression { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: $5.(ast.ExprNode)} + on := &ast.OnCondition{Expr: $5.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on} } | TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression - { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: $7.(ast.ExprNode)} + { + on := &ast.OnCondition{Expr: $7.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on} } /* Support Using */ @@ -2799,7 +2774,7 @@ SelectStmtCalcFoundRows: SelectStmtFieldList: FieldList { - $$ = $1 + $$ = &ast.FieldList{Fields: $1.([]*ast.SelectField)} } SelectStmtGroup: @@ -2836,36 +2811,61 @@ SelectLockOpt: } // See: https://dev.mysql.com/doc/refman/5.7/en/union.html -UnionClause: - "UNION" UnionOpt SelectBasic +UnionStmt: + UnionClauseList { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SelectStmt)} + $$ = $1.(*ast.UnionStmt) } - -UnionClauseP: - "UNION" UnionOpt SubSelect +| UnionClausePList OrderByOptional SelectStmtLimit { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SubqueryExpr).Query} + union := $1.(*ast.UnionStmt) + if $2 != nil { + // push union order by into select statements. + orderBy := $2.(*ast.OrderByClause) + cloner := &ast.Cloner{} + for _, s := range union.Selects { + node, _ := orderBy.Accept(cloner) + s.OrderBy = node.(*ast.OrderByClause) + } + } + if $3 != nil { + union.Limit = $3.(*ast.Limit) + } + $$ = union } UnionClauseList: - UnionClause + SelectStmt "UNION" UnionOpt SelectStmt { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} + selects := []*ast.SelectStmt{$1.(*ast.SelectStmt), $4.(*ast.SelectStmt)} + $$ = &ast.UnionStmt{ + Distinct: $3.(bool), + Selects: selects, + } } -| UnionClauseList UnionClause +| UnionClauseList "UNION" UnionOpt SelectStmt { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union } UnionClausePList: - UnionClauseP + '(' SelectStmt ')' "UNION" UnionOpt '(' SelectStmt ')' { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} + selects := []*ast.SelectStmt{$2.(*ast.SelectStmt), $7.(*ast.SelectStmt)} + $$ = &ast.UnionStmt{ + Distinct: $5.(bool), + Selects: selects, + } } -| UnionClausePList UnionClauseP +| UnionClausePList "UNION" UnionOpt '(' SelectStmt ')' { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) + $$ = union } UnionOpt: @@ -3071,7 +3071,7 @@ ShowStmt: Where: $3.(ast.ExprNode), } } -| "SHOW" "CREATE" "TABLE" TableName +| "SHOW" "CREATE" "TABLE" TableName { $$ = &ast.ShowStmt{ Tp: ast.ShowCreateTable, @@ -3794,17 +3794,17 @@ UpdateStmt: "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { // Single-table syntax - r := &ast.Join{Left: $4.(ast.ResultSetNode), Right: nil} + join := &ast.Join{Left: &ast.TableSource{Source:$4.(ast.ResultSetNode)}, Right: nil} st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: r, + TableRefs: &ast.TableRefsClause{TableRefs: join}, List: $6.([]*ast.Assignment), } if $7 != nil { st.Where = $7.(ast.ExprNode) } if $8 != nil { - st.Order = $8.([]*ast.OrderByItem) + st.Order = $8.([]*ast.ByItem) } if $9 != nil { st.Limit = $9.(*ast.Limit) @@ -3819,7 +3819,7 @@ UpdateStmt: // Multiple-table syntax st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: $4.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, List: $6.([]*ast.Assignment), MultipleTable: true, } diff --git a/ast/parser/parser_test.go b/ast/parser/parser_test.go index 6b32ea65fc..1e2798c28c 100644 --- a/ast/parser/parser_test.go +++ b/ast/parser/parser_test.go @@ -56,8 +56,8 @@ func (s *testParserSuite) TestSimple(c *C) { st := l.Stmts()[0] ss, ok := st.(*ast.SelectStmt) c.Assert(ok, IsTrue) - c.Assert(len(ss.Fields), Equals, 1) - cv, ok := ss.Fields[0].Expr.(*ast.FuncCastExpr) + c.Assert(len(ss.Fields.Fields), Equals, 1) + cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr) c.Assert(ok, IsTrue) c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction) diff --git a/ast/parser/yy_parser.go b/ast/parser/yy_parser.go new file mode 100644 index 0000000000..dfc3264796 --- /dev/null +++ b/ast/parser/yy_parser.go @@ -0,0 +1,19 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser + +// YYParse is an wrapper of `yyParse` to make it exported. +func YYParse(yylex yyLexer) int { + return yyParse(yylex) +} diff --git a/optimizer/infobinder.go b/optimizer/infobinder.go new file mode 100644 index 0000000000..0ad59bceb6 --- /dev/null +++ b/optimizer/infobinder.go @@ -0,0 +1,441 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/model" +) + +// InfoBinder binds schema information for table name and column name and set result fields +// for ResetSetNode. +// We need to know which table a table name refers to, which column a column name refers to. +// +// In general, a reference can only refer to information that are available for it. +// So children elements are visited in the order that previous elements make information +// available for following elements. +// +// During visiting, information are collected and stored in binderContext. +// When we enter a sub query, a new binderContext is pushed to the contextStack, so sub query +// information can overwrite outer query information. When we look up for a column reference, +// we look up from top to bottom in the contextStack. +type InfoBinder struct { + Info infoschema.InfoSchema + DefaultSchema model.CIStr + Err error + + contextStack []*binderContext +} + +// binderContext stores information that table name and column name +// can be bind to. +type binderContext struct { + /* For Select Statement. */ + // table map to lookup and check table name conflict. + tableMap map[string]int + // tableSources collected in from clause. + tables []*ast.TableSource + // result fields collected in select field list. + fieldList []*ast.ResultField + // result fields collected in group by clause. + groupBy []*ast.ResultField + + // The join node stack is used by on condition to find out + // available tables to reference. On condition can only + // refer to tables involved in current join. + joinNodeStack []*ast.Join + + // When visiting TableRefs, tables in this context are not available + // because it is being collected. + inTableRefs bool + // When visiting on conditon only tables in current join node are available. + inOnCondition bool + // When visiting field list, fieldList in this context are not available. + inFieldList bool + // When visiting group by, groupBy fields are not available. + inGroupBy bool + // When visiting having, only fieldList and groupBy fields are available. + inHaving bool +} + +// currentContext gets the current binder context. +func (sb *InfoBinder) currentContext() *binderContext { + stackLen := len(sb.contextStack) + if stackLen == 0 { + return nil + } + return sb.contextStack[stackLen-1] +} + +// pushContext is called when we enter a statement. +func (sb *InfoBinder) pushContext() { + sb.contextStack = append(sb.contextStack, &binderContext{ + tableMap: map[string]int{}, + }) +} + +// popContext is called when we leave a statement. +func (sb *InfoBinder) popContext() { + sb.contextStack = sb.contextStack[:len(sb.contextStack)-1] +} + +// pushJoin is called when we enter a join node. +func (sb *InfoBinder) pushJoin(j *ast.Join) { + ctx := sb.currentContext() + ctx.joinNodeStack = append(ctx.joinNodeStack, j) +} + +// popJoin is called when we leave a join node. +func (sb *InfoBinder) popJoin() { + ctx := sb.currentContext() + ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1] +} + +// Enter implements ast.Visitor interface. +func (sb *InfoBinder) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { + switch v := inNode.(type) { + case *ast.SelectStmt: + sb.pushContext() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = true + case *ast.Join: + sb.pushJoin(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = true + case *ast.FieldList: + sb.currentContext().inFieldList = true + case *ast.GroupByClause: + sb.currentContext().inGroupBy = true + case *ast.HavingClause: + sb.currentContext().inHaving = true + case *ast.InsertStmt: + sb.pushContext() + case *ast.DeleteStmt: + sb.pushContext() + case *ast.UpdateStmt: + sb.pushContext() + } + return inNode, false +} + +// Leave implements ast.Visitor interface. +func (sb *InfoBinder) Leave(inNode ast.Node) (node ast.Node, ok bool) { + switch v := inNode.(type) { + case *ast.TableName: + sb.handleTableName(v) + case *ast.ColumnName: + sb.handleColumnName(v) + case *ast.TableSource: + sb.handleTableSource(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = false + case *ast.Join: + sb.handleJoin(v) + sb.popJoin() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = false + case *ast.FieldList: + sb.handleFieldList(v) + sb.currentContext().inFieldList = false + case *ast.GroupByClause: + sb.currentContext().inGroupBy = false + case *ast.HavingClause: + sb.currentContext().inHaving = false + case *ast.SelectStmt: + v.SetResultFields(sb.currentContext().fieldList) + sb.popContext() + case *ast.InsertStmt: + sb.popContext() + case *ast.DeleteStmt: + sb.popContext() + case *ast.UpdateStmt: + sb.popContext() + } + return inNode, sb.Err == nil +} + +// handleTableName looks up and bind the schema information for table name +// and set result fields for table name. +func (sb *InfoBinder) handleTableName(tn *ast.TableName) { + if tn.Schema.L == "" { + tn.Schema = sb.DefaultSchema + } + table, err := sb.Info.TableByName(tn.Schema, tn.Name) + if err != nil { + sb.Err = err + return + } + tn.TableInfo = table.Meta() + dbInfo, _ := sb.Info.SchemaByName(tn.Schema) + tn.DBInfo = dbInfo + + rfs := make([]*ast.ResultField, len(tn.TableInfo.Columns)) + for i, v := range tn.TableInfo.Columns { + rfs[i] = &ast.ResultField{ + Column: v, + Table: tn.TableInfo, + DBName: tn.Schema, + } + } + tn.SetResultFields(rfs) + return +} + +// handleTableSources checks name duplication +// and puts the table source in current binderContext. +func (sb *InfoBinder) handleTableSource(ts *ast.TableSource) { + for _, v := range ts.GetResultFields() { + v.TableAsName = ts.AsName + } + var name string + if ts.AsName.L != "" { + name = ts.AsName.L + } else { + tableName := ts.Source.(*ast.TableName) + name = sb.tableUniqueName(tableName.Schema, tableName.Name) + } + ctx := sb.currentContext() + if _, ok := ctx.tableMap[name]; ok { + sb.Err = errors.Errorf("duplicated table/alias name %s", name) + return + } + ctx.tableMap[name] = len(ctx.tables) + ctx.tables = append(ctx.tables, ts) + return +} + +// handleJoin sets result fields for join. +func (sb *InfoBinder) handleJoin(j *ast.Join) { + if j.Right == nil { + j.SetResultFields(j.Left.GetResultFields()) + return + } + leftLen := len(j.Left.GetResultFields()) + rightLen := len(j.Right.GetResultFields()) + rfs := make([]*ast.ResultField, leftLen+rightLen) + copy(rfs, j.Left.GetResultFields()) + copy(rfs[leftLen:], j.Right.GetResultFields()) + j.SetResultFields(rfs) +} + +// handleColumnName looks up and binds schema information to +// the column name. +func (sb *InfoBinder) handleColumnName(cn *ast.ColumnName) { + ctx := sb.currentContext() + if ctx.inOnCondition { + // In on condition, only tables within current join is available. + sb.bindColumnNameInOnCondition(cn) + return + } + + // Try to bind the column name form top to bottom in the context stack. + for i := len(sb.contextStack) - 1; i >= 0; i-- { + if sb.bindColumnNameInContext(sb.contextStack[i], cn) { + // Column is already bound or encountered an error. + return + } + } + sb.Err = errors.Errorf("Unknown column %s", cn.Name.L) +} + +// bindColumnNameInContext looks up and binds schema information for a column with the ctx. +func (sb *InfoBinder) bindColumnNameInContext(ctx *binderContext, cn *ast.ColumnName) (done bool) { + if cn.Table.L == "" { + // If qualified table name is not specified in column name, the column name may be ambiguous, + // We need to iterate over all tables and + } + + if ctx.inTableRefs { + // In TableRefsClause, column reference only in join on condition which is handled before. + return false + } + if ctx.inFieldList { + // only bind column using tables. + return sb.bindColumnInTableSources(cn, ctx.tables) + } + if ctx.inGroupBy { + // field list first, then tables. + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + return sb.bindColumnInTableSources(cn, ctx.tables) + } + // column name in other places can be looked up in the same order. + if sb.bindColumnInResultFields(cn, ctx.groupBy) { + return true + } + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + + // tables is not available for having clause. + if !ctx.inHaving { + return sb.bindColumnInTableSources(cn, ctx.tables) + } + return false +} + +// bindColumnNameInOnCondition looks up for column name in current join, and +// binds the schema information. +func (sb *InfoBinder) bindColumnNameInOnCondition(cn *ast.ColumnName) { + ctx := sb.currentContext() + join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1] + tableSources := appendTableSources(nil, join) + if !sb.bindColumnInTableSources(cn, tableSources) { + sb.Err = errors.Errorf("unkown column name %s", cn.Name.O) + } +} + +func (sb *InfoBinder) bindColumnInTableSources(cn *ast.ColumnName, tableSources []*ast.TableSource) (done bool) { + var matchedResultField *ast.ResultField + if cn.Table.L != "" { + var matchedTable ast.ResultSetNode + for _, ts := range tableSources { + if cn.Table.L == ts.AsName.L { + // different table name. + matchedTable = ts + break + } + if tn, ok := ts.Source.(*ast.TableName); ok { + if cn.Table.L == tn.Name.L { + matchedTable = ts + } + } + } + if matchedTable != nil { + resultFields := matchedTable.GetResultFields() + for _, rf := range resultFields { + if rf.ColumnAsName.L == cn.Name.L || rf.Column.Name.L == cn.Name.L { + // bind column. + matchedResultField = rf + break + } + } + } + } else { + for _, ts := range tableSources { + rfs := ts.GetResultFields() + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return true + } + matchedResultField = rf + } + } + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +func (sb *InfoBinder) bindColumnInResultFields(cn *ast.ColumnName, rfs []*ast.ResultField) bool { + var matchedResultField *ast.ResultField + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return false + } + matchedResultField = rf + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +// handleFieldList expands wild card field and set fieldList in current context. +func (sb *InfoBinder) handleFieldList(fieldList *ast.FieldList) { + var resultFields []*ast.ResultField + for _, v := range fieldList.Fields { + resultFields = append(resultFields, sb.createResultFields(v)...) + } + sb.currentContext().fieldList = resultFields +} + +// createResultFields creates result field list for a single select field. +func (sb *InfoBinder) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) { + ctx := sb.currentContext() + if field.WildCard != nil { + if len(ctx.tables) == 0 { + sb.Err = errors.Errorf("No table used.") + return + } + if field.WildCard.Table.L == "" { + for _, v := range ctx.tables { + rfs = append(rfs, v.GetResultFields()...) + } + } else { + name := sb.tableUniqueName(field.WildCard.Schema, field.WildCard.Table) + tableIdx, ok := ctx.tableMap[name] + if !ok { + sb.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O) + } + rfs = ctx.tables[tableIdx].GetResultFields() + } + return + } + // The column is visited before so it must has been bound already. + rf := &ast.ResultField{ColumnAsName: field.AsName} + switch v := field.Expr.(type) { + case *ast.ColumnNameExpr: + rf.Column = v.Name.ColumnInfo + rf.Table = v.Name.TableInfo + rf.DBName = v.Name.Schema + default: + if field.AsName.L == "" { + rf.ColumnAsName.L = field.Expr.Text() + rf.ColumnAsName.O = rf.ColumnAsName.L + } + } + rfs = append(rfs, rf) + return +} + +func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) { + switch v := resultSetNode.(type) { + case *ast.TableSource: + out = append(in, v) + case *ast.Join: + out = appendTableSources(in, v.Left) + if v.Right != nil { + out = appendTableSources(out, v.Right) + } + } + return +} + +func (sb *InfoBinder) tableUniqueName(schema, table model.CIStr) string { + if schema.L != "" && schema.L != sb.DefaultSchema.L { + return schema.L + "." + table.L + } + return table.L +} diff --git a/optimizer/infobinder_test.go b/optimizer/infobinder_test.go new file mode 100644 index 0000000000..f2109e0c17 --- /dev/null +++ b/optimizer/infobinder_test.go @@ -0,0 +1,69 @@ +package optimizer_test + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ast/parser" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/optimizer" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/testkit" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testInfoBinderSuite{}) + +type testInfoBinderSuite struct { +} + +type binderVerifier struct { + c *C +} + +func (bv *binderVerifier) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (bv *binderVerifier) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ColumnName: + bv.c.Assert(v.ColumnInfo, NotNil) + case *ast.TableName: + bv.c.Assert(v.TableInfo, NotNil) + } + return in, true +} + +func (ts *testInfoBinderSuite) TestInfoBinder(c *C) { + store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory) + c.Assert(err, IsNil) + defer store.Close() + testKit := testkit.NewTestKit(c, store) + testKit.MustExec("use test") + testKit.MustExec("create table t (c1 int, c2 int)") + domain := sessionctx.GetDomain(testKit.Se.(context.Context)) + + src := "SELECT c1 from t" + l := parser.NewLexer(src) + c.Assert(parser.YYParse(l), Equals, 0) + stmts := l.Stmts() + c.Assert(len(stmts), Equals, 1) + v := &optimizer.InfoBinder{ + Info: domain.InfoSchema(), + DefaultSchema: model.NewCIStr("test"), + } + selectStmt := stmts[0].(*ast.SelectStmt) + selectStmt.Accept(v) + + verifier := &binderVerifier{ + c: c, + } + selectStmt.Accept(verifier) +} From 0147eedd2c8822a6cc65c266f2098dbceb0f92a5 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 27 Oct 2015 16:05:46 +0800 Subject: [PATCH 02/27] optimiser: implement evaluator. --- ast/ast.go | 4 + ast/base.go | 11 + ast/expressions.go | 2 - ast/functions.go | 29 +- ast/parser/parser.y | 28 +- optimizer/aggregator.go | 13 + optimizer/evaluator.go | 657 ++++++++++++++++++++++++++++++++++++++ optimizer/optimizer.go | 25 +- optimizer/typecomputer.go | 16 + optimizer/validator.go | 16 + 10 files changed, 781 insertions(+), 20 deletions(-) create mode 100644 optimizer/aggregator.go create mode 100644 optimizer/evaluator.go create mode 100644 optimizer/typecomputer.go create mode 100644 optimizer/validator.go diff --git a/ast/ast.go b/ast/ast.go index dabc4fa912..fc505c8362 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -44,6 +44,10 @@ type ExprNode interface { SetType(tp *types.FieldType) // GetType gets the evaluation type of the expression. GetType() *types.FieldType + // SetValue sets value to the expression. + SetValue(val interface{}) + // GetValue gets value of the expression. + GetValue() interface{} } // FuncNode represents function call expression node. diff --git a/ast/base.go b/ast/base.go index 50e9ef030c..2b25bb7f96 100644 --- a/ast/base.go +++ b/ast/base.go @@ -62,6 +62,7 @@ func (dn *dmlNode) dmlStatement() {} // Expression implementations should embed it in. type exprNode struct { node + val interface{} tp *types.FieldType } @@ -80,6 +81,16 @@ func (en *exprNode) GetType() *types.FieldType { return en.tp } +// SetValue implements Expression interface. +func (en *exprNode) SetValue(val interface{}) { + en.val = val +} + +// GetValue implements Expression interface. +func (en *exprNode) GetValue() interface{} { + return en.val +} + type funcNode struct { exprNode } diff --git a/ast/expressions.go b/ast/expressions.go index c99611ae13..f06109d312 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -47,8 +47,6 @@ var ( // ValueExpr is the simple value expression. type ValueExpr struct { exprNode - // Val is the literal value. - Val interface{} } // IsStatic implements ExprNode interface. diff --git a/ast/functions.go b/ast/functions.go index 4e052b3ccf..f3efdbae41 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -36,9 +36,6 @@ type FuncCallExpr struct { F string // Args is the function args. Args []ExprNode - // Distinct only affetcts sum, avg, count, group_concat, - // so we can ignore it in other functions - Distinct bool } // Accept implements Node interface. @@ -335,3 +332,29 @@ func (nod *FuncTrimExpr) IsStatic() bool { // TypeStar is a special type for "*". type TypeStar string + +type AggregateFuncExpr struct { + funcNode + // F is the function name. + F string + // Args is the function args. + Args []ExprNode + Distinct bool +} + +// Accept implements Node Accept interface. +func (nod *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*AggregateFuncExpr) + for i, val := range nod.Args { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Args[i] = node.(ExprNode) + } + return v.Leave(nod) +} \ No newline at end of file diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 97790fd0d5..9397163269 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1888,7 +1888,7 @@ FunctionNameConflict: FunctionCallConflict: FunctionNameConflict '(' ExpressionListOpt ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: false} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURRENT_USER" { @@ -1918,18 +1918,14 @@ DistinctOpt: } FunctionCallKeyword: - "AVG" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "CAST" '(' Expression "AS" CastType ')' + "CAST" '(' Expression "AS" CastType ')' { /* See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_cast */ $$ = &ast.FuncCastExpr{ Expr: $3.(ast.ExprNode), Tp: $5.(*types.FieldType), FunctionType: ast.CastFunction, - } + } } | "CASE" ExpressionOpt WhenClauseList ElseOpt "END" { @@ -2211,30 +2207,34 @@ TrimDirection: } FunctionCallAgg: - "COUNT" '(' DistinctOpt ExpressionList ')' + "AVG" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} + } +| "COUNT" '(' DistinctOpt ExpressionList ')' + { + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt '*' ')' { args := []ast.ExprNode{&ast.ValueExpr{Val: ast.TypeStar("*")} } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "MAX" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "MIN" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "SUM" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } FuncDatetimePrec: diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go new file mode 100644 index 0000000000..95b9c7cb80 --- /dev/null +++ b/optimizer/aggregator.go @@ -0,0 +1,13 @@ +package optimizer + +// Aggregator is the interface to +// compute aggregate function result. +type Aggregator interface { + // Input add a input value to aggregator. + // The input values are accumulated in the aggregator. + Input(in []interface{}) error + // Output use input values to compute the aggregated result. + Output() interface{} + // Clear clears the input values. + Clear() +} diff --git a/optimizer/evaluator.go b/optimizer/evaluator.go new file mode 100644 index 0000000000..52dab1d6f2 --- /dev/null +++ b/optimizer/evaluator.go @@ -0,0 +1,657 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "math" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/plan" + "github.com/pingcap/tidb/util/types" +) + +type Evaluator struct { + // columnMap is the map from ColumnName to the position of the rowStack. + // It is used to find the value of the column. + columnMap map[*ast.ColumnName]position + // rowStack is the current row values while scanning. + // It should be updated after scaned a new row. + rowStack []*plan.Row + + // the map from AggregateFuncExpr to aggregator index. + aggregateMap map[*ast.AggregateFuncExpr]int + + // aggregators for the current row + // only outer query aggregate functions are handled. + aggregators []Aggregator + + // when aggregation phase is done, the input is + aggregateDone bool + + err error +} + +type position struct { + stackOffset int + fieldList bool + columnOffset int +} + +func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + ok = true + case *ast.AggregateFuncExpr: + ok = e.handleAggregateFunc(v) + case *ast.BetweenExpr: + ok = e.handleBetween(v) + case *ast.BinaryOperationExpr: + ok = e.handleBinaryOperation(v) + } + out = in + return +} + +func (e *Evaluator) handleAggregateFunc(v *ast.AggregateFuncExpr) bool { + idx := e.aggregateMap[v] + aggr := e.aggregators[idx] + if e.aggregateDone { + v.SetValue(aggr.Output()) + return true + } + // TODO: currently only single argument aggregate functions are supported. + e.err = aggr.Input(v.Args[0].GetValue()) + return e.err == nil +} + +func checkAllOneColumn(exprs ...ast.ExprNode) bool { + for _, expr := range exprs { + switch v := expr.(type) { + case *ast.RowExpr: + return false + case *ast.SubqueryExpr: + if len(v.Query.Fields.Fields) != 1 { + return false + } + } + } + return true +} + +func (e *Evaluator) handleBetween(v *ast.BetweenExpr) bool { + if !checkAllOneColumn(v.Expr, v.Left, v.Right) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + + var l, r ast.ExprNode + op := opcode.AndAnd + + if v.Not { + // v < lv || v > rv + op = opcode.OrOr + l = &ast.BinaryOperationExpr{opcode.LT, v.Expr, v.Left} + r = &ast.BinaryOperationExpr{opcode.GT, v.Expr, v.Right} + } else { + // v >= lv && v <= rv + l = &ast.BinaryOperationExpr{opcode.GE, v.Expr, v.Left} + r = &ast.BinaryOperationExpr{opcode.LE, v.Expr, v.Right} + } + + ret := &ast.BinaryOperationExpr{op, l, r} + ret.Accept(e) + return e.err == nil +} + +const ( + zeroI64 int64 = 0 + oneI64 int64 = 1 +) + +func (e *Evaluator) handleBinaryOperation(o *ast.BinaryOperationExpr) bool { + // all operands must have same column. + if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { + return nil, false + } + + // row constructor only supports comparison operation. + switch o.Op { + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + default: + if !checkAllOneColumn(o.L) { + return nil, errors.Errorf("Operand should contain 1 column(s)") + } + } + + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + + switch o.Op { + case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: + return e.handleLogicOperation(o) + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + return e.handleComparisonOp(o) + case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: + // TODO: MySQL doesn't support and not, we should remove it later. + return e.handleBitOp(o) + case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: + return e.handleArithmeticOp(o) + default: + panic("should never happen") + } +} + +func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + var boolVal bool + switch o.Op { + case opcode.AndAnd: + boolVal = leftVal != zeroI64 && rightVal != zeroI64 + case opcode.OrOr: + boolVal = leftVal != zeroI64 || rightVal != zeroI64 + case opcode.LogicXor: + boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) + default: + panic("should never happen") + } + if boolVal { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + if types.IsNil(a) || types.IsNil(b) { + // for <=>, if a and b are both nil, return true. + // if a or b is nil, return false. + if o.Op == opcode.NullEQ { + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + } else { + o.SetValue(nil) + } + return true + } + + n, err := types.Compare(a, b) + if err != nil { + e.err = err + return false + } + + r, err := getCompResult(o.Op, n) + if err != nil { + e.err = err + return false + } + if r { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func getCompResult(op opcode.Op, value int) (bool, error) { + switch op { + case opcode.LT: + return value < 0, nil + case opcode.LE: + return value <= 0, nil + case opcode.GE: + return value >= 0, nil + case opcode.GT: + return value > 0, nil + case opcode.EQ: + return value == 0, nil + case opcode.NE: + return value != 0, nil + case opcode.NullEQ: + return value == 0, nil + default: + return false, errors.Errorf("invalid op %v in comparision operation", op) + } +} + +func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + + if types.IsNil(a) || types.IsNil(b) { + return nil, nil + } + + x, err := types.ToInt64(a) + if err != nil { + e.err = err + return false + } + + y, err := types.ToInt64(b) + if err != nil { + e.err = err + return false + } + + // use a int64 for bit operator, return uint64 + switch o.Op { + case opcode.And: + o.SetValue(uint64(x & y)) + case opcode.Or: + o.SetValue(uint64(x | y)) + case opcode.Xor: + o.SetValue(uint64(x ^ y)) + case opcode.RightShift: + o.SetValue(uint64(x) >> uint64(y)) + case opcode.LeftShift: + o.SetValue(uint64(x) << uint64(y)) + default: + e.err = errors.Errorf("invalid op %v in bit operation", o.Op) + return false + } + return true +} + +func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { + a, err := coerceArithmetic(o.L.GetValue()) + if err != nil { + e.err = err + return false + } + + b, err := coerceArithmetic(o.R.GetValue()) + if err != nil { + e.err = err + return false + } + a, b = types.Coerce(a, b) + + if a == nil || b == nil { + // TODO: for <=>, if a and b are both nil, return true + o.SetValue(nil) + return true + } + + // TODO: support logic division DIV + var result interface{} + switch o.Op { + case opcode.Plus: + result, e.err = computePlus(a, b) + case opcode.Minus: + result, e.err = computeMinus(a, b) + case opcode.Mul: + result, e.err = computeMul(a, b) + case opcode.Div: + result, e.err = computeDiv(a, b) + case opcode.Mod: + result, e.err = computeMod(a, b) + case opcode.IntDiv: + result, e.err = computeIntDiv(a, b) + default: + e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) + return false + } + o.SetValue(result) + return e.err == nil +} + +func computePlus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.AddInt64(x, y) + case uint64: + return types.AddInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.AddInteger(x, y) + case uint64: + return types.AddUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x + y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Add(y), nil + } + } + return types.InvOp2(a, b, opcode.Plus) +} + +func computeMinus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.SubInt64(x, y) + case uint64: + return types.SubIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + return types.SubUintWithInt(x, y) + case uint64: + return types.SubUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x - y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Sub(y), nil + } + } + + return types.InvOp2(a, b, opcode.Minus) +} + +func computeMul(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.MulInt64(x, y) + case uint64: + return types.MulInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.MulInteger(x, y) + case uint64: + return types.MulUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x * y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Mul(y), nil + } + } + + return types.InvOp2(a, b, opcode.Mul) +} + +func computeDiv(a, b interface{}) (interface{}, error) { + // MySQL support integer divison Div and division operator / + // we use opcode.Div for division operator and will use another for integer division later. + // for division operator, we will use float64 for calculation. + switch x := a.(type) { + case float64: + y, err := types.ToFloat64(b) + if err != nil { + return nil, err + } + + if y == 0 { + return nil, nil + } + + return x / y, nil + default: + // the scale of the result is the scale of the first operand plus + // the value of the div_precision_increment system variable (which is 4 by default) + // we will use 4 here + + xa, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + xb, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + if f, _ := xb.Float64(); f == 0 { + // division by zero return null + return nil, nil + } + + return xa.Div(xb), nil + } +} + +func computeMod(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return x % y, nil + case uint64: + if y == 0 { + return nil, nil + } else if x < 0 { + // first is int64, return int64. + return -int64(uint64(-x) % y), nil + } + return int64(uint64(x) % y), nil + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } else if y < 0 { + // first is uint64, return uint64. + return uint64(x % uint64(-y)), nil + } + return x % uint64(y), nil + case uint64: + if y == 0 { + return nil, nil + } + return x % y, nil + } + case float64: + switch y := b.(type) { + case float64: + if y == 0 { + return nil, nil + } + return math.Mod(x, y), nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + xf, _ := x.Float64() + yf, _ := y.Float64() + if yf == 0 { + return nil, nil + } + return math.Mod(xf, yf), nil + } + } + + return types.InvOp2(a, b, opcode.Mod) +} + +func computeIntDiv(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivInt64(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return types.DivIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivUintWithInt(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return x / y, nil + } + } + + // if any is none integer, use decimal to calculate + x, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + y, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + + if f, _ := y.Float64(); f == 0 { + return nil, nil + } + + return x.Div(y).IntPart(), nil +} + +func coerceArithmetic(a interface{}) (interface{}, error) { + switch x := a.(type) { + case string: + // MySQL will convert string to float for arithmetic operation + f, err := types.StrToFloat(x) + if err != nil { + return nil, err + } + return f, err + case mysql.Time: + // if time has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case mysql.Duration: + // if duration has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case []byte: + // []byte is the same as string, converted to float for arithmetic operator. + f, err := types.StrToFloat(string(x)) + if err != nil { + return nil, err + } + return f, err + case mysql.Hex: + return x.ToNumber(), nil + case mysql.Bit: + return x.ToNumber(), nil + case mysql.Enum: + return x.ToNumber(), nil + case mysql.Set: + return x.ToNumber(), nil + default: + return x, nil + } +} + +func columnCount(e ast.ExprNode) (int, error) { + switch x := e.(type) { + case *ast.RowExpr: + n := len(x.Values) + if n <= 1 { + return 0, errors.Errorf("Operand should contain >= 2 columns for Row") + } + return n, nil + case *ast.SubqueryExpr: + return len(x.Query.Fields.Fields), nil + default: + return 1, nil + } +} + +func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error { + l, err := columnCount(e) + if err != nil { + return errors.Trace(err) + } + var n int + for _, arg := range args { + n, err = columnCount(arg) + if err != nil { + return errors.Trace(err) + } + + if n != l { + return errors.Errorf("Operand should contain %d column(s)", l) + } + } + return nil +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 36ccd06c5f..aad7c6d732 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -16,17 +16,40 @@ package optimizer import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/stmt" + "github.com/juju/errors" ) // Compile compiles a ast.Node into a executable statement. func Compile(node ast.Node) (stmt.Statement, error) { + validator := &validator{} + if _, ok := node.Accept(validator); !ok { + return nil, errors.Trace(validator.err) + } + + binder := &InfoBinder{} + if _, ok := node.Accept(validator); !ok { + return nil, errors.Trace(binder.Err) + } + + tpComputer := &typeComputer{} + if _, ok := node.Accept(typeComputer{}); !ok { + return nil, errors.Trace(tpComputer.err) + } + switch v := node.(type) { + case *ast.SelectStmt: + case *ast.SetStmt: return compileSet(v) } return nil, nil } -func compileSet(aset *ast.SetStmt) (stmt.Statement, error) { +func compileSelect(s *ast.SelectStmt) (stmt.Statement, error) { + + return nil, nil +} + +func compileSet(s *ast.SetStmt) (stmt.Statement, error) { return nil, nil } diff --git a/optimizer/typecomputer.go b/optimizer/typecomputer.go new file mode 100644 index 0000000000..3da52497bc --- /dev/null +++ b/optimizer/typecomputer.go @@ -0,0 +1,16 @@ +package optimizer +import "github.com/pingcap/tidb/ast" + +// typeComputer is an ast Visitor that +// Compute types for ast.ExprNode. +type typeComputer struct { + err error +} + +func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) { + return +} \ No newline at end of file diff --git a/optimizer/validator.go b/optimizer/validator.go new file mode 100644 index 0000000000..cd69ebd47a --- /dev/null +++ b/optimizer/validator.go @@ -0,0 +1,16 @@ +package optimizer +import "github.com/pingcap/tidb/ast" + +// validator is an ast.Visitor that validates +// ast parsed from parser. +type validator struct { + err error +} + +func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { + return +} \ No newline at end of file From f3e45697def4f5e9428ba440f0d5ff806f1398a3 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 28 Oct 2015 19:42:08 +0800 Subject: [PATCH 03/27] optimiser: convert to old statement and expression. --- ast/base.go | 11 +- ast/cloner_test.go | 6 +- ast/dml.go | 3 +- ast/expressions.go | 40 +- ast/functions.go | 11 +- ast/parser/parser.y | 22 +- optimizer/aggregator.go | 15 +- optimizer/convert_expr.go | 402 ++++++++++++++++++++ optimizer/convert_stmt.go | 562 +++++++++++++++++++++++++++ optimizer/evaluator.go | 718 +++++++++-------------------------- optimizer/evaluator_binop.go | 527 +++++++++++++++++++++++++ optimizer/infobinder_test.go | 13 + optimizer/optimizer.go | 65 +++- optimizer/typecomputer.go | 20 +- optimizer/validator.go | 20 +- 15 files changed, 1865 insertions(+), 570 deletions(-) create mode 100644 optimizer/convert_expr.go create mode 100644 optimizer/convert_stmt.go create mode 100644 optimizer/evaluator_binop.go diff --git a/ast/base.go b/ast/base.go index 2b25bb7f96..877fa80b51 100644 --- a/ast/base.go +++ b/ast/base.go @@ -62,8 +62,7 @@ func (dn *dmlNode) dmlStatement() {} // Expression implementations should embed it in. type exprNode struct { node - val interface{} - tp *types.FieldType + types.DataItem } // IsStatic implements Expression interface. @@ -73,22 +72,22 @@ func (en *exprNode) IsStatic() bool { // SetType implements Expression interface. func (en *exprNode) SetType(tp *types.FieldType) { - en.tp = tp + en.Type = tp } // GetType implements Expression interface. func (en *exprNode) GetType() *types.FieldType { - return en.tp + return en.Type } // SetValue implements Expression interface. func (en *exprNode) SetValue(val interface{}) { - en.val = val + en.Data = val } // GetValue implements Expression interface. func (en *exprNode) GetValue() interface{} { - return en.val + return en.Data } type funcNode struct { diff --git a/ast/cloner_test.go b/ast/cloner_test.go index 186ee824a7..f6f567205b 100644 --- a/ast/cloner_test.go +++ b/ast/cloner_test.go @@ -21,7 +21,7 @@ func (ts *testClonerSuite) TestCloner(c *C) { a := &UnaryOperationExpr{ Op: opcode.Not, - V: &UnaryOperationExpr{V: &ValueExpr{Val: true}}, + V: &UnaryOperationExpr{V: NewValueExpr(true)}, } b, ok := a.Accept(cloner) @@ -35,6 +35,6 @@ func (ts *testClonerSuite) TestCloner(c *C) { a3 := a2.(*ValueExpr) b3 := b2.(*ValueExpr) c.Assert(a3, Not(Equals), b3) - c.Assert(a3.Val, Equals, true) - c.Assert(b3.Val, Equals, true) + c.Assert(a3.GetValue(), Equals, true) + c.Assert(b3.GetValue(), Equals, true) } diff --git a/ast/dml.go b/ast/dml.go index f7f11421bd..d46cd86394 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -22,6 +22,7 @@ var ( _ DMLNode = &DeleteStmt{} _ DMLNode = &UpdateStmt{} _ DMLNode = &SelectStmt{} + _ DMLNode = &UnionStmt{} _ Node = &Join{} _ Node = &TableName{} _ Node = &TableSource{} @@ -114,7 +115,7 @@ type TableSource struct { node // Source is the source of the data, can be a TableName, - // a SubQuery, or a JoinNode. + // a SelectStmt, a UnionStmt, or a JoinNode. Source ResultSetNode // AsName is the as name of the table source. diff --git a/ast/expressions.go b/ast/expressions.go index f06109d312..43fc176498 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -15,7 +15,9 @@ package ast import ( "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" ) var ( @@ -26,6 +28,7 @@ var ( _ ExprNode = &CaseExpr{} _ ExprNode = &SubqueryExpr{} _ ExprNode = &CompareSubqueryExpr{} + _ Node = &ColumnName{} _ ExprNode = &ColumnNameExpr{} _ ExprNode = &DefaultExpr{} _ ExprNode = &IdentifierExpr{} @@ -49,6 +52,39 @@ type ValueExpr struct { exprNode } +// NewValueExpr creates a ValueExpr with value, and sets default field type. +func NewValueExpr(value interface{}) *ValueExpr { + ve := &ValueExpr{} + ve.Data = value + // TODO: make it more precise. + switch value.(type) { + case bool, int64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + case uint64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + ve.Type.Flag |= mysql.UnsignedFlag + case string: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = mysql.DefaultCharset + ve.Type.Collate = mysql.DefaultCollationName + case float64: + ve.Type = types.NewFieldType(mysql.TypeDouble) + case []byte: + ve.Type = types.NewFieldType(mysql.TypeBlob) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + case mysql.Bit: + ve.Type = types.NewFieldType(mysql.TypeBit) + case mysql.Hex: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + default: + panic("illegal literal value type") + } + return ve +} + // IsStatic implements ExprNode interface. func (nod *ValueExpr) IsStatic() bool { return true @@ -247,7 +283,7 @@ func (nod *CaseExpr) IsStatic() bool { type SubqueryExpr struct { exprNode // Query is the query SelectNode. - Query *SelectStmt + Query ResultSetNode } // Accept implements Node Accept interface. @@ -261,7 +297,7 @@ func (nod *SubqueryExpr) Accept(v Visitor) (Node, bool) { if !ok { return nod, false } - nod.Query = node.(*SelectStmt) + nod.Query = node.(ResultSetNode) return v.Leave(nod) } diff --git a/ast/functions.go b/ast/functions.go index f3efdbae41..97b9cd1e3c 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -26,7 +26,9 @@ var ( _ FuncNode = &FuncConvertExpr{} _ FuncNode = &FuncCastExpr{} _ FuncNode = &FuncSubstringExpr{} + _ FuncNode = &FuncLocateExpr{} _ FuncNode = &FuncTrimExpr{} + _ FuncNode = &AggregateFuncExpr{} ) // FuncCallExpr is for function expression. @@ -330,15 +332,16 @@ func (nod *FuncTrimExpr) IsStatic() bool { return nod.Str.IsStatic() && nod.RemStr.IsStatic() } -// TypeStar is a special type for "*". -type TypeStar string - +// AggregateFuncExpr represents aggregate function expression. type AggregateFuncExpr struct { funcNode // F is the function name. F string // Args is the function args. Args []ExprNode + // If distinct is true, the function only aggregate distinct values. + // For example, column c1 values are "1", "2", "2", "sum(c1)" is "5", + // but "sum(distinct c1)" is "3". Distinct bool } @@ -357,4 +360,4 @@ func (nod *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { nod.Args[i] = node.(ExprNode) } return v.Leave(nod) -} \ No newline at end of file +} diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 9397163269..af69b11b61 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -944,15 +944,15 @@ NowSym: SignedLiteral: Literal { - $$ = ast.ValueExpr{Val: $1} + $$ = ast.NewValueExpr($1) } | '+' NumLiteral { - $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: &ast.ValueExpr{Val: $2}} + $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: ast.NewValueExpr($2)} } | '-' NumLiteral { - $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: &ast.ValueExpr{Val: $2}} + $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr($2)} } // TODO: support decimal literal @@ -1752,7 +1752,7 @@ Literal: Operand: Literal { - $$ = &ast.ValueExpr{Val: $1} + $$ = ast.NewValueExpr($1) } | ColumnName { @@ -2217,7 +2217,7 @@ FunctionCallAgg: } | "COUNT" '(' DistinctOpt '*' ')' { - args := []ast.ExprNode{&ast.ValueExpr{Val: ast.TypeStar("*")} } + args := []ast.ExprNode{ast.NewValueExpr("*")} $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' @@ -2649,6 +2649,10 @@ TableFactor: { $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } +| '(' UnionStmt ')' TableAsName + { + $$ = &ast.TableSource{Source: $2.(*ast.UnionStmt), AsName: $4.(model.CIStr)} + } | '(' TableRefs ')' { $$ = $2 @@ -2794,6 +2798,14 @@ SubSelect: s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) $$ = &ast.SubqueryExpr{Query: s} } +| '(' UnionStmt ')' + { + s := $2.(*ast.UnionStmt) + src := yylex.(*lexer).src + // See the implemention of yyParse function + s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) + $$ = &ast.SubqueryExpr{Query: s} + } // See: https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html SelectLockOpt: diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go index 95b9c7cb80..29b2fa3ea3 100644 --- a/optimizer/aggregator.go +++ b/optimizer/aggregator.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer // Aggregator is the interface to @@ -5,7 +18,7 @@ package optimizer type Aggregator interface { // Input add a input value to aggregator. // The input values are accumulated in the aggregator. - Input(in []interface{}) error + Input(in ...interface{}) error // Output use input values to compute the aggregated result. Output() interface{} // Clear clears the input values. diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go new file mode 100644 index 0000000000..4fe74ebf34 --- /dev/null +++ b/optimizer/convert_expr.go @@ -0,0 +1,402 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/subquery" + "github.com/pingcap/tidb/model" + "strings" +) + +func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression.Expression, error) { + expr.Accept(converter) + if converter.err != nil { + return nil, errors.Trace(converter.err) + } + return converter.exprMap[expr], nil +} + +// expressionConverter converts ast expression to +// old expression for transition state. +type expressionConverter struct { + exprMap map[ast.Node]expression.Expression + err error +} + +func newExpressionConverter() *expressionConverter { + return &expressionConverter{ + exprMap: map[ast.Node]expression.Expression{}, + } +} + +// Enter implements ast.Visitor interface. +func (c *expressionConverter) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +// Leave implements ast.Visitor interface. +func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + c.value(v) + case *ast.BetweenExpr: + c.between(v) + case *ast.BinaryOperationExpr: + c.binaryOperation(v) + case *ast.WhenClause: + c.whenClause(v) + case *ast.CaseExpr: + c.caseExpr(v) + case *ast.SubqueryExpr: + c.subquery(v) + case *ast.CompareSubqueryExpr: + c.compareSubquery(v) + case *ast.ColumnNameExpr: + c.columnNameExpr(v) + case *ast.DefaultExpr: + c.defaultExpr(v) + case *ast.IdentifierExpr: + c.identifier(v) + case *ast.ExistsSubqueryExpr: + c.existsSubquery(v) + case *ast.PatternInExpr: + c.patternIn(v) + case *ast.IsNullExpr: + c.isNull(v) + case *ast.IsTruthExpr: + c.isTruth(v) + case *ast.PatternLikeExpr: + c.patternLike(v) + case *ast.ParamMarkerExpr: + c.paramMarker(v) + case *ast.ParenthesesExpr: + c.parentheses(v) + case *ast.PositionExpr: + c.position(v) + case *ast.PatternRegexpExpr: + c.patternRegexp(v) + case *ast.RowExpr: + c.row(v) + case *ast.UnaryOperationExpr: + c.unaryOperation(v) + case *ast.ValuesExpr: + c.values(v) + case *ast.VariableExpr: + c.variable(v) + case *ast.FuncCallExpr: + c.funcCall(v) + case *ast.FuncExtractExpr: + c.funcExtract(v) + case *ast.FuncConvertExpr: + c.funcConvert(v) + case *ast.FuncCastExpr: + c.funcCast(v) + case *ast.FuncSubstringExpr: + c.funcSubstring(v) + case *ast.FuncLocateExpr: + c.funcLocate(v) + case *ast.FuncTrimExpr: + c.funcTrim(v) + case *ast.AggregateFuncExpr: + c.aggregateFunc(v) + } + return in, c.err == nil +} + +func (c *expressionConverter) value(v *ast.ValueExpr) { + c.exprMap[v] = expression.Value{Val: v.GetValue()} +} + +func (c *expressionConverter) between(v *ast.BetweenExpr) { + oldExpr := c.exprMap[v.Expr] + oldLo := c.exprMap[v.Left] + oldHi := c.exprMap[v.Right] + oldBetween, err := expression.NewBetween(oldExpr, oldLo, oldHi, v.Not) + if err != nil { + c.err = err + return + } + c.exprMap[v] = oldBetween +} + +func (c *expressionConverter) binaryOperation(v *ast.BinaryOperationExpr) { + oldeLeft := c.exprMap[v.L] + oldRight := c.exprMap[v.R] + oldBinop := expression.NewBinaryOperation(v.Op, oldeLeft, oldRight) + c.exprMap[v] = oldBinop +} + +func (c *expressionConverter) whenClause(v *ast.WhenClause) { + oldExpr := c.exprMap[v.Expr] + oldResult := c.exprMap[v.Result] + oldWhenClause := &expression.WhenClause{Expr: oldExpr, Result: oldResult} + c.exprMap[v] = oldWhenClause +} + +func (c *expressionConverter) caseExpr(v *ast.CaseExpr) { + oldValue := c.exprMap[v.Value] + oldWhenClauses := make([]*expression.WhenClause, len(v.WhenClauses)) + for i, val := range v.WhenClauses { + oldWhenClauses[i] = c.exprMap[val].(*expression.WhenClause) + } + oldElse := c.exprMap[v.ElseClause] + oldCaseExpr := &expression.FunctionCase{ + Value: oldValue, + WhenClauses: oldWhenClauses, + ElseClause: oldElse, + } + c.exprMap[v] = oldCaseExpr +} + +func (c *expressionConverter) subquery(v *ast.SubqueryExpr) { + oldSubquery := &subquery.SubQuery{} + switch x := v.Query.(type) { + case *ast.SelectStmt: + oldSelect, err := convertSelect(x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldUnion + } + c.exprMap[v] = oldSubquery +} + +func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) { + expr := c.exprMap[v.L] + subquery := c.exprMap[v.R] + oldCmpSubquery := expression.NewCompareSubQuery(v.Op, expr, subquery.(expression.SubQuery), v.All) + c.exprMap[v] = oldCmpSubquery +} + +func concatCIStr(ciStrs ...model.CIStr) string { + var originStrs []string + for _, v := range ciStrs { + if v.O != "" { + originStrs = append(originStrs, v.O) + } + } + return strings.Join(originStrs, ".") +} + +func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) { + ident := &expression.Ident{} + ident.CIStr = model.NewCIStr(concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name)) + c.exprMap[v] = ident +} + +func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) { + oldDefault := &expression.Default{} + if v.Name != nil { + oldDefault.Name = concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name) + } + c.exprMap[v] = oldDefault +} + +func (c *expressionConverter) identifier(v *ast.IdentifierExpr) { + oldIdent := &expression.Ident{} + oldIdent.CIStr = v.Name + c.exprMap[v] = oldIdent +} + +func (c *expressionConverter) existsSubquery(v *ast.ExistsSubqueryExpr) { + subquery := c.exprMap[v.Sel].(expression.SubQuery) + c.exprMap[v] = &expression.ExistsSubQuery{Sel: subquery} +} + +func (c *expressionConverter) patternIn(v *ast.PatternInExpr) { + oldPatternIn := &expression.PatternIn{Not: v.Not} + if v.Sel != nil { + oldPatternIn.Sel = c.exprMap[v.Sel].(expression.SubQuery) + } + oldPatternIn.Expr = c.exprMap[v.Expr] + if v.List != nil { + oldPatternIn.List = make([]expression.Expression, len(v.List)) + for i, v := range v.List { + oldPatternIn.List[i] = c.exprMap[v] + } + } + c.exprMap[v] = oldPatternIn +} + +func (c *expressionConverter) isNull(v *ast.IsNullExpr) { + oldIsNull := &expression.IsNull{Not: v.Not} + oldIsNull.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsNull +} + +func (c *expressionConverter) isTruth(v *ast.IsTruthExpr) { + oldIsTruth := &expression.IsTruth{Not: v.Not, True: v.True} + oldIsTruth.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsTruth +} + +func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { + oldPatternLike := &expression.PatternLike{ + Not: v.Not, + Escape: v.Escape, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternLike +} + +func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { + c.exprMap[v] = &expression.ParamMarker{} +} + +func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { + oldExpr := c.exprMap[v.Expr] + c.exprMap[v] = &expression.PExpr{Expr: oldExpr} +} + +func (c *expressionConverter) position(v *ast.PositionExpr) { + c.exprMap[v] = &expression.Position{N: v.N, Name: v.Name} +} + +func (c *expressionConverter) patternRegexp(v *ast.PatternRegexpExpr) { + oldPatternRegexp := &expression.PatternRegexp{ + Not: v.Not, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternRegexp +} + +func (c *expressionConverter) row(v *ast.RowExpr) { + oldRow := &expression.Row{} + oldRow.Values = make([]expression.Expression, len(v.Values)) + for i, val := range v.Values { + oldRow.Values[i] = c.exprMap[val] + } + c.exprMap[v] = oldRow +} + +func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) { + oldUnary := &expression.UnaryOperation{ + Op: v.Op, + V: c.exprMap[v.V], + } + c.exprMap[v] = oldUnary +} + +func (c *expressionConverter) values(v *ast.ValuesExpr) { + nameStr := concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)} +} + +func (c *expressionConverter) variable(v *ast.VariableExpr) { + c.exprMap[v] = &expression.Variable{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + } +} + +func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) { + oldCall := &expression.Call{ + F: v.F, + } + oldCall.Args = make([]expression.Expression, len(v.Args)) + for i, val := range v.Args { + oldCall.Args[i] = c.exprMap[val] + } + c.exprMap[v] = oldCall +} + +func (c *expressionConverter) funcExtract(v *ast.FuncExtractExpr) { + oldExtract := &expression.Extract{Unit: v.Unit} + oldExtract.Date = c.exprMap[v.Date] + c.exprMap[v] = oldExtract +} + +func (c *expressionConverter) funcConvert(v *ast.FuncConvertExpr) { + c.exprMap[v] = &expression.FunctionConvert{ + Expr: c.exprMap[v.Expr], + Charset: v.Charset, + } +} + +func (c *expressionConverter) funcCast(v *ast.FuncCastExpr) { + oldCast := &expression.FunctionCast{ + Expr: c.exprMap[v.Expr], + Tp: v.Tp, + } + switch v.FunctionType { + case ast.CastBinaryOperator: + oldCast.FunctionType = expression.BinaryOperator + case ast.CastConvertFunction: + oldCast.FunctionType = expression.ConvertFunction + case ast.CastFunction: + oldCast.FunctionType = expression.CastFunction + } + c.exprMap[v] = oldCast +} + +func (c *expressionConverter) funcSubstring(v *ast.FuncSubstringExpr) { + oldSubstring := &expression.FunctionSubstring{ + Len: c.exprMap[v.Len], + Pos: c.exprMap[v.Pos], + StrExpr: c.exprMap[v.StrExpr], + } + c.exprMap[v] = oldSubstring +} + +func (c *expressionConverter) funcLocate(v *ast.FuncLocateExpr) { + oldLocate := &expression.FunctionLocate{ + Pos: c.exprMap[v.Pos], + Str: c.exprMap[v.Str], + SubStr: c.exprMap[v.SubStr], + } + c.exprMap[v] = oldLocate +} + +func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) { + oldTrim := &expression.FunctionTrim{ + Str: c.exprMap[v.Str], + RemStr: c.exprMap[v.RemStr], + } + switch v.Direction { + case ast.TrimBoth: + oldTrim.Direction = expression.TrimBoth + case ast.TrimBothDefault: + oldTrim.Direction = expression.TrimBothDefault + case ast.TrimLeading: + oldTrim.Direction = expression.TrimLeading + case ast.TrimTrailing: + oldTrim.Direction = expression.TrimTrailing + } + c.exprMap[v] = oldTrim +} + +func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) { + oldAggregate := &expression.Call{ + F: v.F, + Distinct: v.Distinct, + } + for _, val := range v.Args { + oldAggregate.Args = append(oldAggregate.Args, c.exprMap[val]) + } + c.exprMap[v] = oldAggregate +} diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go new file mode 100644 index 0000000000..34e81052b9 --- /dev/null +++ b/optimizer/convert_stmt.go @@ -0,0 +1,562 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/parser/coldef" + "github.com/pingcap/tidb/rset/rsets" + "github.com/pingcap/tidb/stmt" + "github.com/pingcap/tidb/stmt/stmts" + "github.com/pingcap/tidb/table" +) + +func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { + oldInsert := &stmts.InsertIntoStmt{ + Text: v.Text(), + } + return oldInsert, nil +} + +func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { + oldDelete := &stmts.DeleteStmt{ + Text: v.Text(), + } + return oldDelete, nil +} + +func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { + oldUpdate := &stmts.UpdateStmt{ + Text: v.Text(), + } + return oldUpdate, nil +} + +func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { + converter := &expressionConverter{ + exprMap: map[ast.Node]expression.Expression{}, + } + oldSelect := &stmts.SelectStmt{} + oldSelect.Distinct = s.Distinct + oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) + for i, val := range s.Fields.Fields { + oldField := &field.Field{} + oldField.AsName = val.AsName.O + var err error + oldField.Expr, err = convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldSelect.Fields[i] = oldField + } + var err error + if s.From != nil { + oldSelect.From, err = convertJoin(converter, s.From.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Where != nil { + oldSelect.Where = &rsets.WhereRset{} + oldSelect.Where.Expr, err = convertExpr(converter, s.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + + if s.GroupBy != nil { + oldSelect.GroupBy, err = convertGroupBy(converter, s.GroupBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Having != nil { + oldSelect.Having, err = convertHaving(converter, s.Having) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.OrderBy != nil { + oldSelect.OrderBy, err = convertOrderBy(converter, s.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Limit != nil { + if s.Limit.Offset > 0 { + oldSelect.Offset = &rsets.OffsetRset{Count: s.Limit.Offset} + } + if s.Limit.Count > 0 { + oldSelect.Limit = &rsets.LimitRset{Count: s.Limit.Count} + } + } + switch s.LockTp { + case ast.SelectLockForUpdate: + oldSelect.Lock = coldef.SelectLockForUpdate + case ast.SelectLockInShareMode: + oldSelect.Lock = coldef.SelectLockInShareMode + case ast.SelectLockNone: + oldSelect.Lock = coldef.SelectLockNone + } + return oldSelect, nil +} + +func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { + oldUnion := &stmts.UnionStmt{} + oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) + oldUnion.Distincts = make([]bool, len(u.Selects)-1) + if u.Distinct { + for i := range oldUnion.Distincts { + oldUnion.Distincts[i] = true + } + } + for i, val := range u.Selects { + oldSelect, err := convertSelect(val) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.Selects[i] = oldSelect + } + if u.Limit != nil { + if u.Limit.Offset > 0 { + oldUnion.Offset = &rsets.OffsetRset{Count: u.Limit.Offset} + } + if u.Limit.Count > 0 { + oldUnion.Limit = &rsets.LimitRset{Count: u.Limit.Count} + } + } + // Union order by can not converted to old because it is pushed to select statements. + return oldUnion, nil +} + +func convertJoin(converter *expressionConverter, join *ast.Join) (*rsets.JoinRset, error) { + oldJoin := &rsets.JoinRset{} + switch join.Tp { + case ast.CrossJoin: + oldJoin.Type = rsets.CrossJoin + case ast.LeftJoin: + oldJoin.Type = rsets.LeftJoin + case ast.RightJoin: + oldJoin.Type = rsets.RightJoin + } + switch l := join.Left.(type) { + case *ast.Join: + oldLeft, err := convertJoin(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + case *ast.TableSource: + oldLeft, err := convertTableSource(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + } + + switch r := join.Left.(type) { + case *ast.Join: + oldRight, err := convertJoin(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case *ast.TableSource: + oldRight, err := convertTableSource(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case nil: + } + + if join.On != nil { + oldOn, err := convertExpr(converter, join.On.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.On = oldOn + } + return oldJoin, nil +} + +func convertTableSource(converter *expressionConverter, ts *ast.TableSource) (*rsets.TableSource, error) { + oldTs := &rsets.TableSource{} + oldTs.Name = ts.AsName.O + switch src := ts.Source.(type) { + case *ast.TableName: + oldTs.Source = table.Ident{Schema: src.Schema, Name: src.Name} + case *ast.SelectStmt: + oldSelect, err := convertSelect(src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldUnion + } + return oldTs, nil +} + +func convertGroupBy(converter *expressionConverter, gb *ast.GroupByClause) (*rsets.GroupByRset, error) { + oldGroupBy := &rsets.GroupByRset{ + By: make([]expression.Expression, len(gb.Items)), + } + for i, val := range gb.Items { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldGroupBy.By[i] = oldExpr + } + return oldGroupBy, nil +} + +func convertHaving(converter *expressionConverter, having *ast.HavingClause) (*rsets.HavingRset, error) { + oldHaving := &rsets.HavingRset{} + oldExpr, err := convertExpr(converter, having.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldHaving.Expr = oldExpr + return oldHaving, nil +} + +func convertOrderBy(converter *expressionConverter, orderBy *ast.OrderByClause) (*rsets.OrderByRset, error) { + oldOrderBy := &rsets.OrderByRset{} + oldOrderBy.By = make([]rsets.OrderByItem, len(orderBy.Items)) + for i, val := range orderBy.Items { + oldByItemExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldOrderBy.By[i].Expr = oldByItemExpr + oldOrderBy.By[i].Asc = !val.Desc + } + return oldOrderBy, nil +} + +func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { + oldCreateDatabase := &stmts.CreateDatabaseStmt{ + IfNotExists: v.IfNotExists, + Name: v.Name, + Text: v.Text(), + } + if len(v.Options) != 0 { + oldCreateDatabase.Opt = &coldef.CharsetOpt{} + for _, val := range v.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + oldCreateDatabase.Opt.Chs = val.Value + case ast.DatabaseOptionCollate: + oldCreateDatabase.Opt.Col = val.Value + } + } + } + return oldCreateDatabase, nil +} + +func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { + return &stmts.DropDatabaseStmt{ + IfExists: v.IfExists, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { + oldCreateTable := &stmts.CreateTableStmt{ + Text: v.Text(), + } + return oldCreateTable, nil +} + +func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { + oldDropTable := &stmts.DropTableStmt{ + IfExists: v.IfExists, + Text: v.Text(), + } + oldDropTable.TableIdents = make([]table.Ident, len(v.Tables)) + for i, val := range v.Tables { + oldDropTable.TableIdents[i] = table.Ident{ + Schema: val.Schema, + Name: val.Name, + } + } + return oldDropTable, nil +} + +func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { + oldCreateIndex := &stmts.CreateIndexStmt{ + IndexName: v.IndexName, + Unique: v.Unique, + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + } + oldCreateIndex.IndexColNames = make([]*coldef.IndexColName, len(v.IndexColNames)) + for i, val := range v.IndexColNames { + oldIndexColName := &coldef.IndexColName{ + ColumnName: concatCIStr(val.Column.Schema, val.Column.Table, val.Column.Name), + Length: val.Length, + } + oldCreateIndex.IndexColNames[i] = oldIndexColName + } + return oldCreateIndex, nil +} + +func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { + return &stmts.DropIndexStmt{ + IfExists: v.IfExists, + IndexName: v.IndexName, + Text: v.Text(), + }, nil +} + +func convertAlterTable(v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { + oldAlterTable := &stmts.AlterTableStmt{ + Text: v.Text(), + } + return oldAlterTable, nil +} + +func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { + return &stmts.TruncateTableStmt{ + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + }, nil +} + +func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { + oldExplain := &stmts.ExplainStmt{ + Text: v.Text(), + } + var err error + switch x := v.Stmt.(type) { + case *ast.SelectStmt: + oldExplain.S, err = convertSelect(x) + case *ast.UpdateStmt: + oldExplain.S, err = convertUpdate(x) + case *ast.DeleteStmt: + oldExplain.S, err = convertDelete(x) + case *ast.InsertStmt: + oldExplain.S, err = convertInsert(x) + } + if err != nil { + return nil, errors.Trace(err) + } + return oldExplain, nil +} + +func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { + oldPrepare := &stmts.PreparedStmt{ + InPrepare: true, + Name: v.Name, + SQLText: v.SQLText, + Text: v.Text(), + } + if v.SQLVar != nil { + converter := newExpressionConverter() + oldSQLVar, err := convertExpr(converter, v.SQLVar) + if err != nil { + return nil, errors.Trace(err) + } + oldPrepare.SQLVar = oldSQLVar.(*expression.Variable) + } + return oldPrepare, nil +} + +func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { + return &stmts.DeallocateStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { + oldExec := &stmts.ExecuteStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + } + oldExec.UsingVars = make([]expression.Expression, len(v.UsingVars)) + converter := newExpressionConverter() + for i, val := range v.UsingVars { + oldVar, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldExec.UsingVars[i] = oldVar + } + return oldExec, nil +} + +func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { + oldShow := &stmts.ShowStmt{ + DBName: v.DBName, + Flag: v.Flag, + Full: v.Full, + GlobalScope: v.GlobalScope, + Text: v.Text(), + } + if v.Table != nil { + oldShow.TableIdent = table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + } + } + if v.Column != nil { + oldShow.ColumnName = concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + } + if v.Where != nil { + converter := newExpressionConverter() + oldWhere, err := convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Where = oldWhere + } + if v.Pattern != nil { + converter := newExpressionConverter() + oldPattern, err := convertExpr(converter, v.Pattern) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Pattern = oldPattern.(*expression.PatternLike) + } + switch v.Tp { + case ast.ShowCharset: + oldShow.Target = stmt.ShowCharset + case ast.ShowCollation: + oldShow.Target = stmt.ShowCollation + case ast.ShowColumns: + oldShow.Target = stmt.ShowColumns + case ast.ShowCreateTable: + oldShow.Target = stmt.ShowCreateTable + case ast.ShowDatabases: + oldShow.Target = stmt.ShowDatabases + case ast.ShowTables: + oldShow.Target = stmt.ShowTables + case ast.ShowEngines: + oldShow.Target = stmt.ShowEngines + case ast.ShowVariables: + oldShow.Target = stmt.ShowVariables + case ast.ShowWarnings: + oldShow.Target = stmt.ShowWarnings + case ast.ShowNone: + oldShow.Target = stmt.ShowNone + } + return oldShow, nil +} + +func convertBegin(v *ast.BeginStmt) (*stmts.BeginStmt, error) { + return &stmts.BeginStmt{ + Text: v.Text(), + }, nil +} + +func convertCommit(v *ast.CommitStmt) (*stmts.CommitStmt, error) { + return &stmts.CommitStmt{ + Text: v.Text(), + }, nil +} + +func convertRollback(v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { + return &stmts.RollbackStmt{ + Text: v.Text(), + }, nil +} + +func convertUse(v *ast.UseStmt) (*stmts.UseStmt, error) { + return &stmts.UseStmt{ + DBName: v.DBName, + Text: v.Text(), + }, nil +} + +func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAssignment) (*stmts.VariableAssignment, error) { + oldValue, err := convertExpr(converter, v.Value) + if err != nil { + return nil, errors.Trace(err) + } + + return &stmts.VariableAssignment{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + Value: oldValue, + Text: v.Text(), + }, nil +} + +func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { + oldSet := &stmts.SetStmt{ + Text: v.Text(), + Variables: make([]*stmts.VariableAssignment, len(v.Variables)), + } + converter := newExpressionConverter() + for i, val := range v.Variables { + oldAssign, err := convertVariableAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldSet.Variables[i] = oldAssign + } + return oldSet, nil +} + +func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { + return &stmts.SetCharsetStmt{ + Charset: v.Charset, + Collate: v.Collate, + Text: v.Text(), + }, nil +} + +func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { + return &stmts.SetPwdStmt{ + User: v.User, + Password: v.Password, + Text: v.Text(), + }, nil +} + +func convertDo(v *ast.DoStmt) (*stmts.DoStmt, error) { + exprConverter := newExpressionConverter() + oldDo := &stmts.DoStmt{ + Text: v.Text(), + Exprs: make([]expression.Expression, len(v.Exprs)), + } + for i, val := range v.Exprs { + oldExpr, err := convertExpr(exprConverter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldDo.Exprs[i] = oldExpr + } + return oldDo, nil +} diff --git a/optimizer/evaluator.go b/optimizer/evaluator.go index 52dab1d6f2..fffb5b53b3 100644 --- a/optimizer/evaluator.go +++ b/optimizer/evaluator.go @@ -14,16 +14,13 @@ package optimizer import ( - "math" - "github.com/juju/errors" "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan" - "github.com/pingcap/tidb/util/types" ) +// Evaluator is a ast visitor that evaluates an expression. type Evaluator struct { // columnMap is the map from ColumnName to the position of the rowStack. // It is used to find the value of the column. @@ -51,44 +48,88 @@ type position struct { columnOffset int } +// Enter implements ast.Visitor interface. func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { return } +// Leave implements ast.Visitor interface. func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) { switch v := in.(type) { case *ast.ValueExpr: ok = true - case *ast.AggregateFuncExpr: - ok = e.handleAggregateFunc(v) case *ast.BetweenExpr: - ok = e.handleBetween(v) + ok = e.between(v) case *ast.BinaryOperationExpr: - ok = e.handleBinaryOperation(v) + ok = e.binaryOperation(v) + case *ast.WhenClause: + ok = e.whenClause(v) + case *ast.CaseExpr: + ok = e.caseExpr(v) + case *ast.SubqueryExpr: + ok = e.subquery(v) + case *ast.CompareSubqueryExpr: + ok = e.compareSubquery(v) + case *ast.ColumnNameExpr: + ok = e.columnName(v) + case *ast.DefaultExpr: + ok = e.defaultExpr(v) + case *ast.IdentifierExpr: + ok = e.identifier(v) + case *ast.ExistsSubqueryExpr: + ok = e.existsSubquery(v) + case *ast.PatternInExpr: + ok = e.patternIn(v) + case *ast.IsNullExpr: + ok = e.isNull(v) + case *ast.IsTruthExpr: + ok = e.isTruth(v) + case *ast.PatternLikeExpr: + ok = e.patternLike(v) + case *ast.ParamMarkerExpr: + ok = e.paramMarker(v) + case *ast.ParenthesesExpr: + ok = e.parentheses(v) + case *ast.PositionExpr: + ok = e.position(v) + case *ast.PatternRegexpExpr: + ok = e.patternRegexp(v) + case *ast.RowExpr: + ok = e.row(v) + case *ast.UnaryOperationExpr: + ok = e.unaryOperation(v) + case *ast.ValuesExpr: + ok = e.values(v) + case *ast.VariableExpr: + ok = e.variable(v) + case *ast.FuncCallExpr: + ok = e.funcCall(v) + case *ast.FuncExtractExpr: + ok = e.funcExtract(v) + case *ast.FuncConvertExpr: + ok = e.funcConvert(v) + case *ast.FuncCastExpr: + ok = e.funcCast(v) + case *ast.FuncSubstringExpr: + ok = e.funcSubstring(v) + case *ast.FuncLocateExpr: + ok = e.funcLocate(v) + case *ast.FuncTrimExpr: + ok = e.funcTrim(v) + case *ast.AggregateFuncExpr: + ok = e.aggregateFunc(v) } out = in return } -func (e *Evaluator) handleAggregateFunc(v *ast.AggregateFuncExpr) bool { - idx := e.aggregateMap[v] - aggr := e.aggregators[idx] - if e.aggregateDone { - v.SetValue(aggr.Output()) - return true - } - // TODO: currently only single argument aggregate functions are supported. - e.err = aggr.Input(v.Args[0].GetValue()) - return e.err == nil -} - func checkAllOneColumn(exprs ...ast.ExprNode) bool { for _, expr := range exprs { switch v := expr.(type) { case *ast.RowExpr: return false case *ast.SubqueryExpr: - if len(v.Query.Fields.Fields) != 1 { + if len(v.Query.GetResultFields()) != 1 { return false } } @@ -96,7 +137,7 @@ func checkAllOneColumn(exprs ...ast.ExprNode) bool { return true } -func (e *Evaluator) handleBetween(v *ast.BetweenExpr) bool { +func (e *Evaluator) between(v *ast.BetweenExpr) bool { if !checkAllOneColumn(v.Expr, v.Left, v.Right) { e.err = errors.Errorf("Operand should contain 1 column(s)") return false @@ -108,520 +149,19 @@ func (e *Evaluator) handleBetween(v *ast.BetweenExpr) bool { if v.Not { // v < lv || v > rv op = opcode.OrOr - l = &ast.BinaryOperationExpr{opcode.LT, v.Expr, v.Left} - r = &ast.BinaryOperationExpr{opcode.GT, v.Expr, v.Right} + l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right} } else { // v >= lv && v <= rv - l = &ast.BinaryOperationExpr{opcode.GE, v.Expr, v.Left} - r = &ast.BinaryOperationExpr{opcode.LE, v.Expr, v.Right} + l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right} } - ret := &ast.BinaryOperationExpr{op, l, r} + ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r} ret.Accept(e) return e.err == nil } -const ( - zeroI64 int64 = 0 - oneI64 int64 = 1 -) - -func (e *Evaluator) handleBinaryOperation(o *ast.BinaryOperationExpr) bool { - // all operands must have same column. - if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { - return nil, false - } - - // row constructor only supports comparison operation. - switch o.Op { - case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: - default: - if !checkAllOneColumn(o.L) { - return nil, errors.Errorf("Operand should contain 1 column(s)") - } - } - - leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - if leftVal == nil || rightVal == nil { - o.SetValue(nil) - return true - } - - switch o.Op { - case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: - return e.handleLogicOperation(o) - case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: - return e.handleComparisonOp(o) - case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: - // TODO: MySQL doesn't support and not, we should remove it later. - return e.handleBitOp(o) - case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: - return e.handleArithmeticOp(o) - default: - panic("should never happen") - } -} - -func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { - leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) - if err != nil { - e.err = err - return false - } - if leftVal == nil || rightVal == nil { - o.SetValue(nil) - return true - } - var boolVal bool - switch o.Op { - case opcode.AndAnd: - boolVal = leftVal != zeroI64 && rightVal != zeroI64 - case opcode.OrOr: - boolVal = leftVal != zeroI64 || rightVal != zeroI64 - case opcode.LogicXor: - boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) - default: - panic("should never happen") - } - if boolVal { - o.SetValue(oneI64) - } else { - o.SetValue(zeroI64) - } - return true -} - -func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { - a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) - if types.IsNil(a) || types.IsNil(b) { - // for <=>, if a and b are both nil, return true. - // if a or b is nil, return false. - if o.Op == opcode.NullEQ { - if types.IsNil(a) || types.IsNil(b) { - o.SetValue(oneI64) - } else { - o.SetValue(zeroI64) - } - } else { - o.SetValue(nil) - } - return true - } - - n, err := types.Compare(a, b) - if err != nil { - e.err = err - return false - } - - r, err := getCompResult(o.Op, n) - if err != nil { - e.err = err - return false - } - if r { - o.SetValue(oneI64) - } else { - o.SetValue(zeroI64) - } - return true -} - -func getCompResult(op opcode.Op, value int) (bool, error) { - switch op { - case opcode.LT: - return value < 0, nil - case opcode.LE: - return value <= 0, nil - case opcode.GE: - return value >= 0, nil - case opcode.GT: - return value > 0, nil - case opcode.EQ: - return value == 0, nil - case opcode.NE: - return value != 0, nil - case opcode.NullEQ: - return value == 0, nil - default: - return false, errors.Errorf("invalid op %v in comparision operation", op) - } -} - -func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { - a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) - - if types.IsNil(a) || types.IsNil(b) { - return nil, nil - } - - x, err := types.ToInt64(a) - if err != nil { - e.err = err - return false - } - - y, err := types.ToInt64(b) - if err != nil { - e.err = err - return false - } - - // use a int64 for bit operator, return uint64 - switch o.Op { - case opcode.And: - o.SetValue(uint64(x & y)) - case opcode.Or: - o.SetValue(uint64(x | y)) - case opcode.Xor: - o.SetValue(uint64(x ^ y)) - case opcode.RightShift: - o.SetValue(uint64(x) >> uint64(y)) - case opcode.LeftShift: - o.SetValue(uint64(x) << uint64(y)) - default: - e.err = errors.Errorf("invalid op %v in bit operation", o.Op) - return false - } - return true -} - -func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { - a, err := coerceArithmetic(o.L.GetValue()) - if err != nil { - e.err = err - return false - } - - b, err := coerceArithmetic(o.R.GetValue()) - if err != nil { - e.err = err - return false - } - a, b = types.Coerce(a, b) - - if a == nil || b == nil { - // TODO: for <=>, if a and b are both nil, return true - o.SetValue(nil) - return true - } - - // TODO: support logic division DIV - var result interface{} - switch o.Op { - case opcode.Plus: - result, e.err = computePlus(a, b) - case opcode.Minus: - result, e.err = computeMinus(a, b) - case opcode.Mul: - result, e.err = computeMul(a, b) - case opcode.Div: - result, e.err = computeDiv(a, b) - case opcode.Mod: - result, e.err = computeMod(a, b) - case opcode.IntDiv: - result, e.err = computeIntDiv(a, b) - default: - e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) - return false - } - o.SetValue(result) - return e.err == nil -} - -func computePlus(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - return types.AddInt64(x, y) - case uint64: - return types.AddInteger(y, x) - } - case uint64: - switch y := b.(type) { - case int64: - return types.AddInteger(x, y) - case uint64: - return types.AddUint64(x, y) - } - case float64: - switch y := b.(type) { - case float64: - return x + y, nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - return x.Add(y), nil - } - } - return types.InvOp2(a, b, opcode.Plus) -} - -func computeMinus(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - return types.SubInt64(x, y) - case uint64: - return types.SubIntWithUint(x, y) - } - case uint64: - switch y := b.(type) { - case int64: - return types.SubUintWithInt(x, y) - case uint64: - return types.SubUint64(x, y) - } - case float64: - switch y := b.(type) { - case float64: - return x - y, nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - return x.Sub(y), nil - } - } - - return types.InvOp2(a, b, opcode.Minus) -} - -func computeMul(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - return types.MulInt64(x, y) - case uint64: - return types.MulInteger(y, x) - } - case uint64: - switch y := b.(type) { - case int64: - return types.MulInteger(x, y) - case uint64: - return types.MulUint64(x, y) - } - case float64: - switch y := b.(type) { - case float64: - return x * y, nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - return x.Mul(y), nil - } - } - - return types.InvOp2(a, b, opcode.Mul) -} - -func computeDiv(a, b interface{}) (interface{}, error) { - // MySQL support integer divison Div and division operator / - // we use opcode.Div for division operator and will use another for integer division later. - // for division operator, we will use float64 for calculation. - switch x := a.(type) { - case float64: - y, err := types.ToFloat64(b) - if err != nil { - return nil, err - } - - if y == 0 { - return nil, nil - } - - return x / y, nil - default: - // the scale of the result is the scale of the first operand plus - // the value of the div_precision_increment system variable (which is 4 by default) - // we will use 4 here - - xa, err := types.ToDecimal(a) - if err != nil { - return nil, err - } - - xb, err := types.ToDecimal(b) - if err != nil { - return nil, err - } - if f, _ := xb.Float64(); f == 0 { - // division by zero return null - return nil, nil - } - - return xa.Div(xb), nil - } -} - -func computeMod(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } - return x % y, nil - case uint64: - if y == 0 { - return nil, nil - } else if x < 0 { - // first is int64, return int64. - return -int64(uint64(-x) % y), nil - } - return int64(uint64(x) % y), nil - } - case uint64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } else if y < 0 { - // first is uint64, return uint64. - return uint64(x % uint64(-y)), nil - } - return x % uint64(y), nil - case uint64: - if y == 0 { - return nil, nil - } - return x % y, nil - } - case float64: - switch y := b.(type) { - case float64: - if y == 0 { - return nil, nil - } - return math.Mod(x, y), nil - } - case mysql.Decimal: - switch y := b.(type) { - case mysql.Decimal: - xf, _ := x.Float64() - yf, _ := y.Float64() - if yf == 0 { - return nil, nil - } - return math.Mod(xf, yf), nil - } - } - - return types.InvOp2(a, b, opcode.Mod) -} - -func computeIntDiv(a, b interface{}) (interface{}, error) { - switch x := a.(type) { - case int64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } - return types.DivInt64(x, y) - case uint64: - if y == 0 { - return nil, nil - } - return types.DivIntWithUint(x, y) - } - case uint64: - switch y := b.(type) { - case int64: - if y == 0 { - return nil, nil - } - return types.DivUintWithInt(x, y) - case uint64: - if y == 0 { - return nil, nil - } - return x / y, nil - } - } - - // if any is none integer, use decimal to calculate - x, err := types.ToDecimal(a) - if err != nil { - return nil, err - } - - y, err := types.ToDecimal(b) - if err != nil { - return nil, err - } - - if f, _ := y.Float64(); f == 0 { - return nil, nil - } - - return x.Div(y).IntPart(), nil -} - -func coerceArithmetic(a interface{}) (interface{}, error) { - switch x := a.(type) { - case string: - // MySQL will convert string to float for arithmetic operation - f, err := types.StrToFloat(x) - if err != nil { - return nil, err - } - return f, err - case mysql.Time: - // if time has no precision, return int64 - v := x.ToNumber() - if x.Fsp == 0 { - return v.IntPart(), nil - } - return v, nil - case mysql.Duration: - // if duration has no precision, return int64 - v := x.ToNumber() - if x.Fsp == 0 { - return v.IntPart(), nil - } - return v, nil - case []byte: - // []byte is the same as string, converted to float for arithmetic operator. - f, err := types.StrToFloat(string(x)) - if err != nil { - return nil, err - } - return f, err - case mysql.Hex: - return x.ToNumber(), nil - case mysql.Bit: - return x.ToNumber(), nil - case mysql.Enum: - return x.ToNumber(), nil - case mysql.Set: - return x.ToNumber(), nil - default: - return x, nil - } -} - func columnCount(e ast.ExprNode) (int, error) { switch x := e.(type) { case *ast.RowExpr: @@ -631,7 +171,7 @@ func columnCount(e ast.ExprNode) (int, error) { } return n, nil case *ast.SubqueryExpr: - return len(x.Query.Fields.Fields), nil + return len(x.Query.GetResultFields()), nil default: return 1, nil } @@ -655,3 +195,123 @@ func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error { } return nil } + +func (e *Evaluator) whenClause(v *ast.WhenClause) bool { + return true +} + +func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { + return true +} + +func (e *Evaluator) subquery(v *ast.SubqueryExpr) bool { + return true +} + +func (e *Evaluator) compareSubquery(v *ast.CompareSubqueryExpr) bool { + return true +} + +func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool { + return true +} + +func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool { + return true +} + +func (e *Evaluator) identifier(v *ast.IdentifierExpr) bool { + return true +} + +func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { + return true +} + +func (e *Evaluator) patternIn(v *ast.PatternInExpr) bool { + return true +} + +func (e *Evaluator) isNull(v *ast.IsNullExpr) bool { + return true +} + +func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { + return true +} + +func (e *Evaluator) patternLike(v *ast.PatternLikeExpr) bool { + return true +} + +func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool { + return true +} + +func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool { + return true +} + +func (e *Evaluator) position(v *ast.PositionExpr) bool { + return true +} + +func (e *Evaluator) patternRegexp(v *ast.PatternRegexpExpr) bool { + return true +} + +func (e *Evaluator) row(v *ast.RowExpr) bool { + return true +} + +func (e *Evaluator) unaryOperation(v *ast.UnaryOperationExpr) bool { + return true +} + +func (e *Evaluator) values(v *ast.ValuesExpr) bool { + return true +} + +func (e *Evaluator) variable(v *ast.VariableExpr) bool { + return true +} + +func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool { + return true +} + +func (e *Evaluator) funcExtract(v *ast.FuncExtractExpr) bool { + return true +} + +func (e *Evaluator) funcConvert(v *ast.FuncConvertExpr) bool { + return true +} + +func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool { + return true +} + +func (e *Evaluator) funcSubstring(v *ast.FuncSubstringExpr) bool { + return true +} + +func (e *Evaluator) funcLocate(v *ast.FuncLocateExpr) bool { + return true +} + +func (e *Evaluator) funcTrim(v *ast.FuncTrimExpr) bool { + return true +} + +func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { + idx := e.aggregateMap[v] + aggr := e.aggregators[idx] + if e.aggregateDone { + v.SetValue(aggr.Output()) + return true + } + // TODO: currently only single argument aggregate functions are supported. + e.err = aggr.Input(v.Args[0].GetValue()) + return e.err == nil +} diff --git a/optimizer/evaluator_binop.go b/optimizer/evaluator_binop.go new file mode 100644 index 0000000000..0cda326144 --- /dev/null +++ b/optimizer/evaluator_binop.go @@ -0,0 +1,527 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "math" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" +) + +const ( + zeroI64 int64 = 0 + oneI64 int64 = 1 +) + +func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool { + // all operands must have same column. + if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { + return false + } + + // row constructor only supports comparison operation. + switch o.Op { + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + default: + if !checkAllOneColumn(o.L) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + } + + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + + switch o.Op { + case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: + return e.handleLogicOperation(o) + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + return e.handleComparisonOp(o) + case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: + // TODO: MySQL doesn't support and not, we should remove it later. + return e.handleBitOp(o) + case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: + return e.handleArithmeticOp(o) + default: + panic("should never happen") + } +} + +func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + var boolVal bool + switch o.Op { + case opcode.AndAnd: + boolVal = leftVal != zeroI64 && rightVal != zeroI64 + case opcode.OrOr: + boolVal = leftVal != zeroI64 || rightVal != zeroI64 + case opcode.LogicXor: + boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) + default: + panic("should never happen") + } + if boolVal { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + if types.IsNil(a) || types.IsNil(b) { + // for <=>, if a and b are both nil, return true. + // if a or b is nil, return false. + if o.Op == opcode.NullEQ { + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + } else { + o.SetValue(nil) + } + return true + } + + n, err := types.Compare(a, b) + if err != nil { + e.err = err + return false + } + + r, err := getCompResult(o.Op, n) + if err != nil { + e.err = err + return false + } + if r { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func getCompResult(op opcode.Op, value int) (bool, error) { + switch op { + case opcode.LT: + return value < 0, nil + case opcode.LE: + return value <= 0, nil + case opcode.GE: + return value >= 0, nil + case opcode.GT: + return value > 0, nil + case opcode.EQ: + return value == 0, nil + case opcode.NE: + return value != 0, nil + case opcode.NullEQ: + return value == 0, nil + default: + return false, errors.Errorf("invalid op %v in comparision operation", op) + } +} + +func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(nil) + return true + } + + x, err := types.ToInt64(a) + if err != nil { + e.err = err + return false + } + + y, err := types.ToInt64(b) + if err != nil { + e.err = err + return false + } + + // use a int64 for bit operator, return uint64 + switch o.Op { + case opcode.And: + o.SetValue(uint64(x & y)) + case opcode.Or: + o.SetValue(uint64(x | y)) + case opcode.Xor: + o.SetValue(uint64(x ^ y)) + case opcode.RightShift: + o.SetValue(uint64(x) >> uint64(y)) + case opcode.LeftShift: + o.SetValue(uint64(x) << uint64(y)) + default: + e.err = errors.Errorf("invalid op %v in bit operation", o.Op) + return false + } + return true +} + +func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { + a, err := coerceArithmetic(o.L.GetValue()) + if err != nil { + e.err = err + return false + } + + b, err := coerceArithmetic(o.R.GetValue()) + if err != nil { + e.err = err + return false + } + a, b = types.Coerce(a, b) + + if a == nil || b == nil { + // TODO: for <=>, if a and b are both nil, return true + o.SetValue(nil) + return true + } + + // TODO: support logic division DIV + var result interface{} + switch o.Op { + case opcode.Plus: + result, e.err = computePlus(a, b) + case opcode.Minus: + result, e.err = computeMinus(a, b) + case opcode.Mul: + result, e.err = computeMul(a, b) + case opcode.Div: + result, e.err = computeDiv(a, b) + case opcode.Mod: + result, e.err = computeMod(a, b) + case opcode.IntDiv: + result, e.err = computeIntDiv(a, b) + default: + e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) + return false + } + o.SetValue(result) + return e.err == nil +} + +func computePlus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.AddInt64(x, y) + case uint64: + return types.AddInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.AddInteger(x, y) + case uint64: + return types.AddUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x + y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Add(y), nil + } + } + return types.InvOp2(a, b, opcode.Plus) +} + +func computeMinus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.SubInt64(x, y) + case uint64: + return types.SubIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + return types.SubUintWithInt(x, y) + case uint64: + return types.SubUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x - y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Sub(y), nil + } + } + + return types.InvOp2(a, b, opcode.Minus) +} + +func computeMul(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.MulInt64(x, y) + case uint64: + return types.MulInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.MulInteger(x, y) + case uint64: + return types.MulUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x * y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Mul(y), nil + } + } + + return types.InvOp2(a, b, opcode.Mul) +} + +func computeDiv(a, b interface{}) (interface{}, error) { + // MySQL support integer divison Div and division operator / + // we use opcode.Div for division operator and will use another for integer division later. + // for division operator, we will use float64 for calculation. + switch x := a.(type) { + case float64: + y, err := types.ToFloat64(b) + if err != nil { + return nil, err + } + + if y == 0 { + return nil, nil + } + + return x / y, nil + default: + // the scale of the result is the scale of the first operand plus + // the value of the div_precision_increment system variable (which is 4 by default) + // we will use 4 here + + xa, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + xb, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + if f, _ := xb.Float64(); f == 0 { + // division by zero return null + return nil, nil + } + + return xa.Div(xb), nil + } +} + +func computeMod(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return x % y, nil + case uint64: + if y == 0 { + return nil, nil + } else if x < 0 { + // first is int64, return int64. + return -int64(uint64(-x) % y), nil + } + return int64(uint64(x) % y), nil + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } else if y < 0 { + // first is uint64, return uint64. + return uint64(x % uint64(-y)), nil + } + return x % uint64(y), nil + case uint64: + if y == 0 { + return nil, nil + } + return x % y, nil + } + case float64: + switch y := b.(type) { + case float64: + if y == 0 { + return nil, nil + } + return math.Mod(x, y), nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + xf, _ := x.Float64() + yf, _ := y.Float64() + if yf == 0 { + return nil, nil + } + return math.Mod(xf, yf), nil + } + } + + return types.InvOp2(a, b, opcode.Mod) +} + +func computeIntDiv(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivInt64(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return types.DivIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivUintWithInt(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return x / y, nil + } + } + + // if any is none integer, use decimal to calculate + x, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + y, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + + if f, _ := y.Float64(); f == 0 { + return nil, nil + } + + return x.Div(y).IntPart(), nil +} + +func coerceArithmetic(a interface{}) (interface{}, error) { + switch x := a.(type) { + case string: + // MySQL will convert string to float for arithmetic operation + f, err := types.StrToFloat(x) + if err != nil { + return nil, err + } + return f, err + case mysql.Time: + // if time has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case mysql.Duration: + // if duration has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case []byte: + // []byte is the same as string, converted to float for arithmetic operator. + f, err := types.StrToFloat(string(x)) + if err != nil { + return nil, err + } + return f, err + case mysql.Hex: + return x.ToNumber(), nil + case mysql.Bit: + return x.ToNumber(), nil + case mysql.Enum: + return x.ToNumber(), nil + case mysql.Set: + return x.ToNumber(), nil + default: + return x, nil + } +} diff --git a/optimizer/infobinder_test.go b/optimizer/infobinder_test.go index f2109e0c17..988d6dfaa0 100644 --- a/optimizer/infobinder_test.go +++ b/optimizer/infobinder_test.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer_test import ( diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index aad7c6d732..cec754d164 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -14,9 +14,9 @@ package optimizer import ( + "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/stmt" - "github.com/juju/errors" ) // Compile compiles a ast.Node into a executable statement. @@ -32,24 +32,63 @@ func Compile(node ast.Node) (stmt.Statement, error) { } tpComputer := &typeComputer{} - if _, ok := node.Accept(typeComputer{}); !ok { + if _, ok := node.Accept(tpComputer); !ok { return nil, errors.Trace(tpComputer.err) } switch v := node.(type) { + case *ast.InsertStmt: + return convertInsert(v) + case *ast.DeleteStmt: + return convertDelete(v) + case *ast.UpdateStmt: + return convertUpdate(v) case *ast.SelectStmt: - + return convertSelect(v) + case *ast.UnionStmt: + return convertUnion(v) + case *ast.CreateDatabaseStmt: + return convertCreateDatabase(v) + case *ast.DropDatabaseStmt: + return convertDropDatabase(v) + case *ast.CreateTableStmt: + return convertCreateTable(v) + case *ast.DropTableStmt: + return convertDropTable(v) + case *ast.CreateIndexStmt: + return convertCreateIndex(v) + case *ast.DropIndexStmt: + return convertDropIndex(v) + case *ast.AlterTableStmt: + return convertAlterTable(v) + case *ast.TruncateTableStmt: + return convertTruncateTable(v) + case *ast.ExplainStmt: + return convertExplain(v) + case *ast.PrepareStmt: + return convertPrepare(v) + case *ast.DeallocateStmt: + return convertDeallocate(v) + case *ast.ExecuteStmt: + return convertExecute(v) + case *ast.ShowStmt: + return convertShow(v) + case *ast.BeginStmt: + return convertBegin(v) + case *ast.CommitStmt: + return convertCommit(v) + case *ast.RollbackStmt: + return convertRollback(v) + case *ast.UseStmt: + return convertUse(v) case *ast.SetStmt: - return compileSet(v) + return convertSet(v) + case *ast.SetCharsetStmt: + return convertSetCharset(v) + case *ast.SetPwdStmt: + return convertSetPwd(v) + case *ast.DoStmt: + return convertDo(v) } return nil, nil } - -func compileSelect(s *ast.SelectStmt) (stmt.Statement, error) { - - return nil, nil -} - -func compileSet(s *ast.SetStmt) (stmt.Statement, error) { - return nil, nil -} diff --git a/optimizer/typecomputer.go b/optimizer/typecomputer.go index 3da52497bc..d1862d1380 100644 --- a/optimizer/typecomputer.go +++ b/optimizer/typecomputer.go @@ -1,4 +1,18 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer + import "github.com/pingcap/tidb/ast" // typeComputer is an ast Visitor that @@ -8,9 +22,9 @@ type typeComputer struct { } func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) { - return + return in, false } func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) { - return -} \ No newline at end of file + return in, true +} diff --git a/optimizer/validator.go b/optimizer/validator.go index cd69ebd47a..f90a757ab8 100644 --- a/optimizer/validator.go +++ b/optimizer/validator.go @@ -1,4 +1,18 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package optimizer + import "github.com/pingcap/tidb/ast" // validator is an ast.Visitor that validates @@ -8,9 +22,9 @@ type validator struct { } func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { - return + return in, false } func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { - return -} \ No newline at end of file + return in, true +} From edcbbc84ff83f811e4708fadeabb74b14fdcb0e1 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 29 Oct 2015 15:33:54 +0800 Subject: [PATCH 04/27] optimizer: implement ddl statement convert functions. --- ast/dml.go | 4 +- ast/parser/parser.y | 4 + optimizer/convert_expr.go | 18 +-- optimizer/convert_stmt.go | 274 +++++++++++++++++++++++++++++++++++++- 4 files changed, 284 insertions(+), 16 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index d46cd86394..0225af0462 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -564,7 +564,7 @@ type InsertStmt struct { Setlist []*Assignment Priority int OnDuplicate []*Assignment - Select *SelectStmt + Select ResultSetNode } // Accept implements Node Accept interface. @@ -616,7 +616,7 @@ func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { if !ok { return nod, false } - nod.Select = node.(*SelectStmt) + nod.Select = node.(ResultSetNode) } return v.Leave(nod) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index af69b11b61..6f096b8687 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1669,6 +1669,10 @@ InsertRest: { $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} } +| UnionStmt + { + $$ = &ast.InsertStmt{Select: $1.(*ast.UnionStmt)} + } | "SET" ColumnSetValueList { $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 4fe74ebf34..41c78a815e 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -190,26 +190,28 @@ func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) { c.exprMap[v] = oldCmpSubquery } -func concatCIStr(ciStrs ...model.CIStr) string { +func joinColumnName(columnName *ast.ColumnName) string { var originStrs []string - for _, v := range ciStrs { - if v.O != "" { - originStrs = append(originStrs, v.O) - } + if columnName.Schema.O != "" { + originStrs = append(originStrs, columnName.Schema.O) } + if columnName.Table.O != "" { + originStrs = append(originStrs, columnName.Table.O) + } + originStrs = append(originStrs, columnName.Name.O) return strings.Join(originStrs, ".") } func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) { ident := &expression.Ident{} - ident.CIStr = model.NewCIStr(concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name)) + ident.CIStr = model.NewCIStr(joinColumnName(v.Name)) c.exprMap[v] = ident } func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) { oldDefault := &expression.Default{} if v.Name != nil { - oldDefault.Name = concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name) + oldDefault.Name = joinColumnName(v.Name) } c.exprMap[v] = oldDefault } @@ -302,7 +304,7 @@ func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) { } func (c *expressionConverter) values(v *ast.ValuesExpr) { - nameStr := concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + nameStr := joinColumnName(v.Column) c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)} } diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 34e81052b9..036bfebefd 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -25,23 +25,153 @@ import ( "github.com/pingcap/tidb/table" ) +func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expression.Assignment, error) { + oldAssign := &expression.Assignment{ + ColName: joinColumnName(v.Column), + } + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldAssign.Expr = oldExpr + return oldAssign, nil +} + func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { oldInsert := &stmts.InsertIntoStmt{ - Text: v.Text(), + Priority: v.Priority, + Text: v.Text(), + } + tableName := v.Table.TableRefs.Left.(*ast.TableName) + oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} + for _, val := range v.Columns { + oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) + } + converter := newExpressionConverter() + for _, row := range v.Lists { + var oldRow []expression.Expression + for _, val := range row { + oldExpr, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldRow = append(oldRow, oldExpr) + } + oldInsert.Lists = append(oldInsert.Lists, oldRow) + } + for _, assign := range v.Setlist { + oldAssign, err := convertAssignment(converter, assign) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.Setlist = append(oldInsert.Setlist, oldAssign) + } + for _, onDup := range v.OnDuplicate { + oldOnDup, err := convertAssignment(converter, onDup) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, *oldOnDup) + } + if v.Select != nil { + var err error + switch x := v.Select.(type) { + case *ast.SelectStmt: + oldInsert.Sel, err = convertSelect(x) + case *ast.UnionStmt: + oldInsert.Sel, err = convertUnion(x) + } + if err != nil { + return nil, errors.Trace(err) + } } return oldInsert, nil } func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { oldDelete := &stmts.DeleteStmt{ - Text: v.Text(), + BeforeFrom: v.BeforeFrom, + Ignore: v.Ignore, + LowPriority: v.LowPriority, + MultiTable: v.MultiTable, + Quick: v.Quick, + Text: v.Text(), + } + converter := newExpressionConverter() + oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + oldDelete.Refs = oldRefs + for _, val := range v.Tables { + tableIdent := table.Ident{Schema: val.Schema, Name: val.Name} + oldDelete.TableIdents = append(oldDelete.TableIdents, tableIdent) + } + if v.Where != nil { + oldDelete.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Order != nil { + orderRset := &rsets.OrderByRset{} + for _, val := range v.Order { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} + orderRset.By = append(orderRset.By, orderItem) + } + oldDelete.Order = orderRset + } + if v.Limit != nil { + oldDelete.Limit = &rsets.LimitRset{Count: v.Limit.Count} } return oldDelete, nil } func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { oldUpdate := &stmts.UpdateStmt{ - Text: v.Text(), + Ignore: v.Ignore, + MultipleTable: v.MultipleTable, + LowPriority: v.LowPriority, + Text: v.Text(), + } + converter := newExpressionConverter() + var err error + oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + if v.Where != nil { + oldUpdate.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + for _, val := range v.List { + var oldAssign *expression.Assignment + oldAssign, err = convertAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldUpdate.List = append(oldUpdate.List, *oldAssign) + } + if v.Order != nil { + orderRset := &rsets.OrderByRset{} + for _, val := range v.Order { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} + orderRset.By = append(orderRset.By, orderItem) + } + oldUpdate.Order = orderRset + } + if v.Limit != nil { + oldUpdate.Limit = &rsets.LimitRset{Count: v.Limit.Count} } return oldUpdate, nil } @@ -282,9 +412,141 @@ func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, erro }, nil } +func convertColumnOption(converter *expressionConverter, v *ast.ColumnOption) (*coldef.ConstraintOpt, error) { + oldColumnOpt := &coldef.ConstraintOpt{} + switch v.Tp { + case ast.ColumnOptionAutoIncrement: + oldColumnOpt.Tp = coldef.ConstrAutoIncrement + case ast.ColumnOptionComment: + oldColumnOpt.Tp = coldef.ConstrComment + case ast.ColumnOptionDefaultValue: + oldColumnOpt.Tp = coldef.ConstrDefaultValue + case ast.ColumnOptionIndex: + oldColumnOpt.Tp = coldef.ConstrIndex + case ast.ColumnOptionKey: + oldColumnOpt.Tp = coldef.ConstrKey + case ast.ColumnOptionFulltext: + oldColumnOpt.Tp = coldef.ConstrFulltext + case ast.ColumnOptionNotNull: + oldColumnOpt.Tp = coldef.ConstrNotNull + case ast.ColumnOptionNoOption: + oldColumnOpt.Tp = coldef.ConstrNoConstr + case ast.ColumnOptionOnUpdate: + oldColumnOpt.Tp = coldef.ConstrOnUpdate + case ast.ColumnOptionPrimaryKey: + oldColumnOpt.Tp = coldef.ConstrPrimaryKey + case ast.ColumnOptionNull: + oldColumnOpt.Tp = coldef.ConstrNull + case ast.ColumnOptionUniq: + oldColumnOpt.Tp = coldef.ConstrUniq + case ast.ColumnOptionUniqIndex: + oldColumnOpt.Tp = coldef.ConstrUniqIndex + case ast.ColumnOptionUniqKey: + oldColumnOpt.Tp = coldef.ConstrUniqKey + } + if v.Expr != nil { + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldColumnOpt.Evalue = oldExpr + } + return oldColumnOpt, nil +} + +func convertColumnDef(converter *expressionConverter, v *ast.ColumnDef) (*coldef.ColumnDef, error) { + oldColDef := &coldef.ColumnDef{ + Name: v.Name.Name.O, + Tp: v.Tp, + } + for _, val := range v.Options { + oldOpt, err := convertColumnOption(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldColDef.Constraints = append(oldColDef.Constraints, oldOpt) + } + return oldColDef, nil +} + +func convertIndexColNames(v []*ast.IndexColName) (out []*coldef.IndexColName) { + for _, val := range v { + oldIndexColKey := &coldef.IndexColName{ + ColumnName: val.Column.Name.O, + Length: val.Length, + } + out = append(out, oldIndexColKey) + } + return +} + +func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*coldef.TableConstraint, error) { + oldConstraint := &coldef.TableConstraint{ConstrName: v.Name} + switch v.Tp { + case ast.ConstraintNoConstraint: + oldConstraint.Tp = coldef.ConstrNoConstr + case ast.ConstraintPrimaryKey: + oldConstraint.Tp = coldef.ConstrPrimaryKey + case ast.ConstraintKey: + oldConstraint.Tp = coldef.ConstrKey + case ast.ConstraintIndex: + oldConstraint.Tp = coldef.ConstrIndex + case ast.ConstraintUniq: + oldConstraint.Tp = coldef.ConstrUniq + case ast.ConstraintUniqKey: + oldConstraint.Tp = coldef.ConstrUniqKey + case ast.ConstraintUniqIndex: + oldConstraint.Tp = coldef.ConstrUniqIndex + case ast.ConstraintForeignKey: + oldConstraint.Tp = coldef.ConstrForeignKey + case ast.ConstraintFulltext: + oldConstraint.Tp = coldef.ConstrFulltext + } + oldConstraint.Keys = convertIndexColNames(v.Keys) + if v.Refer != nil { + oldConstraint.Refer = &coldef.ReferenceDef{ + TableIdent: table.Ident{Schema: v.Refer.Table.Schema, Name: v.Refer.Table.Name}, + IndexColNames: convertIndexColNames(v.Refer.IndexColNames), + } + } + return oldConstraint, nil +} + func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { oldCreateTable := &stmts.CreateTableStmt{ - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + converter := newExpressionConverter() + for _, val := range v.Cols { + oldColDef, err := convertColumnDef(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Cols = append(oldCreateTable.Cols, oldColDef) + } + for _, val := range v.Constraints { + oldConstr, err := convertConstraint(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Constraints = append(oldCreateTable.Constraints, oldConstr) + } + if len(v.Options) != 0 { + oldTableOpt := &coldef.TableOption{} + for _, val := range v.Options { + switch val.Tp { + case ast.TableOptionEngine: + oldTableOpt.Engine = val.StrValue + case ast.TableOptionCharset: + oldTableOpt.Charset = val.StrValue + case ast.TableOptionCollate: + oldTableOpt.Collate = val.StrValue + case ast.TableOptionAutoIncrement: + oldTableOpt.AutoIncrement = val.UintValue + } + } + oldCreateTable.Opt = oldTableOpt } return oldCreateTable, nil } @@ -317,7 +579,7 @@ func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) oldCreateIndex.IndexColNames = make([]*coldef.IndexColName, len(v.IndexColNames)) for i, val := range v.IndexColNames { oldIndexColName := &coldef.IndexColName{ - ColumnName: concatCIStr(val.Column.Schema, val.Column.Table, val.Column.Name), + ColumnName: joinColumnName(val.Column), Length: val.Length, } oldCreateIndex.IndexColNames[i] = oldIndexColName @@ -430,7 +692,7 @@ func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { } } if v.Column != nil { - oldShow.ColumnName = concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + oldShow.ColumnName = joinColumnName(v.Column) } if v.Where != nil { converter := newExpressionConverter() From 7c7473969f99aa21143bd4a674a0a69de480a12c Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 29 Oct 2015 20:29:02 +0800 Subject: [PATCH 05/27] tidb: switch to use ast parser. --- ast/cloner.go | 7 +- ast/dml.go | 8 ++ ast/expressions.go | 5 +- ast/parser/parser.y | 101 +++++++++--------- ast/parser/scanner.l | 13 ++- optimizer/convert_expr.go | 4 +- optimizer/convert_stmt.go | 208 ++++++++++++++++++++++++++++---------- optimizer/optimizer.go | 80 +++++++++------ tidb.go | 23 +++-- tidb_test.go | 1 + 10 files changed, 297 insertions(+), 153 deletions(-) diff --git a/ast/cloner.go b/ast/cloner.go index f08ef09476..bf7b3c961d 100644 --- a/ast/cloner.go +++ b/ast/cloner.go @@ -13,6 +13,8 @@ package ast +import "fmt" + // Cloner is a ast visitor that clones a node. type Cloner struct { } @@ -57,6 +59,9 @@ func cloneStruct(in Node) (out Node) { case *ColumnName: nv := *v out = &nv + case *ColumnNameExpr: + nv := *v + out = &nv case *DefaultExpr: nv := *v out = &nv @@ -162,7 +167,7 @@ func cloneStruct(in Node) (out Node) { default: // We currently only handle expression and select statement. // Will add more when we need to. - panic("unknown ast Node type") + panic("unknown ast Node type " + fmt.Sprintf("%T", v)) } return } diff --git a/ast/dml.go b/ast/dml.go index 0225af0462..b19e60bfc6 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -488,6 +488,7 @@ type UnionStmt struct { Distinct bool Selects []*SelectStmt + OrderBy *OrderByClause Limit *Limit } @@ -505,6 +506,13 @@ func (nod *UnionStmt) Accept(v Visitor) (Node, bool) { } nod.Selects[i] = node.(*SelectStmt) } + if nod.OrderBy != nil { + node, ok := nod.OrderBy.Accept(v) + if !ok { + return nod, false + } + nod.OrderBy = node.(*OrderByClause) + } if nod.Limit != nil { node, ok := nod.Limit.Accept(v) if !ok { diff --git a/ast/expressions.go b/ast/expressions.go index 43fc176498..e32b14752a 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -14,6 +14,7 @@ package ast import ( + "fmt" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" @@ -58,6 +59,8 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Data = value // TODO: make it more precise. switch value.(type) { + case nil: + ve.Type = types.NewFieldType(mysql.TypeNull) case bool, int64: ve.Type = types.NewFieldType(mysql.TypeLonglong) case uint64: @@ -80,7 +83,7 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Type.Charset = "binary" ve.Type.Collate = "binary" default: - panic("illegal literal value type") + panic("illegal literal value type:" + fmt.Sprintf("T", value)) } return ve } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 6f096b8687..643195111a 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -773,13 +773,9 @@ ColumnNameListOpt: { $$ = []*ast.ColumnName{} } -| '(' ')' +| ColumnNameList { - $$ = []*ast.ColumnName{} - } -| '(' ColumnNameList ')' - { - $$ = $2.([]*ast.ColumnName) + $$ = $1.([]*ast.ColumnName) } CommitStmt: @@ -1144,12 +1140,13 @@ DeleteFromStmt: LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - Order: $8.([]*ast.ByItem), } if $7 != nil { x.Where = $7.(ast.ExprNode) } - + if $8 != nil { + x.Order = $8.([]*ast.ByItem) + } if $9 != nil { x.Limit = $9.(*ast.Limit) } @@ -1501,13 +1498,13 @@ Field: } | Expression FieldAsNameOpt { - $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: $2.(model.CIStr)} + $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: model.NewCIStr($2.(string))} } FieldAsNameOpt: /* EMPTY */ { - $$ = model.CIStr{} + $$ = "" } | FieldAsName { @@ -2213,7 +2210,7 @@ TrimDirection: FunctionCallAgg: "AVG" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt ExpressionList ')' { @@ -2550,14 +2547,32 @@ RollbackStmt: } SelectStmt: - "SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt + "SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt { - $$ = &ast.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), Fields: $3.(*ast.FieldList), - From: nil, - LockTp: $6.(ast.SelectLockType), + LockTp: $5.(ast.SelectLockType), } + if $4 != nil { + st.Limit = $4.(*ast.Limit) + } + $$ = st + } +| "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt + { + st := &ast.SelectStmt { + Distinct: $2.(bool), + Fields: $3.(*ast.FieldList), + LockTp: $7.(ast.SelectLockType), + } + if $5 != nil { + st.Where = $5.(ast.ExprNode) + } + if $6 != nil { + st.Limit = $6.(*ast.Limit) + } + $$ = st } | "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional @@ -2594,8 +2609,7 @@ SelectStmt: } FromDual: - /* Empty */ -| "FROM" "DUAL" + "FROM" "DUAL" TableRefsClause: @@ -2836,13 +2850,7 @@ UnionStmt: { union := $1.(*ast.UnionStmt) if $2 != nil { - // push union order by into select statements. - orderBy := $2.(*ast.OrderByClause) - cloner := &ast.Cloner{} - for _, s := range union.Selects { - node, _ := orderBy.Accept(cloner) - s.OrderBy = node.(*ast.OrderByClause) - } + union.OrderBy = $2.(*ast.OrderByClause) } if $3 != nil { union.Limit = $3.(*ast.Limit) @@ -3052,12 +3060,15 @@ ShowStmt: } | "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowTables, DBName: $4.(string), Full: $2.(bool), - Where: $5.(ast.ExprNode), } + if $5 != nil { + stmt.Where = $5.(ast.ExprNode) + } + $$ = stmt } | "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt { @@ -3074,18 +3085,24 @@ ShowStmt: } | "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowVariables, GlobalScope: $2.(bool), - Where: $4.(ast.ExprNode), } + if $4 != nil { + stmt.Where = $4.(ast.ExprNode) + } + $$ = stmt } | "SHOW" "COLLATION" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowCollation, - Where: $3.(ast.ExprNode), } + if $3 != nil { + stmt.Where = $3.(ast.ExprNode) + } + $$ = stmt } | "SHOW" "CREATE" "TABLE" TableName { @@ -3174,6 +3191,7 @@ Statement: | PreparedStmt | RollbackStmt | SelectStmt +| UnionStmt | SetStmt | ShowStmt | TruncateTableStmt @@ -3807,13 +3825,11 @@ StringName: * See: https://dev.mysql.com/doc/refman/5.7/en/update.html ***********************************************************************************/ UpdateStmt: - "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause + "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { - // Single-table syntax - join := &ast.Join{Left: &ast.TableSource{Source:$4.(ast.ResultSetNode)}, Right: nil} st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: join}, + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, List: $6.([]*ast.Assignment), } if $7 != nil { @@ -3830,23 +3846,6 @@ UpdateStmt: break } } -| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional - { - // Multiple-table syntax - st := &ast.UpdateStmt{ - LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, - List: $6.([]*ast.Assignment), - MultipleTable: true, - } - if $7 != nil { - st.Where = $7.(ast.ExprNode) - } - $$ = st - if yylex.(*lexer).root { - break - } - } UseStmt: "USE" DBName diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index a23a68a72a..cdbe765347 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/stringutil" ) type lexer struct { @@ -46,6 +47,7 @@ type lexer struct { val []byte ungetBuf []byte root bool + prepare bool stmtStartPos int stringLit []byte @@ -86,6 +88,14 @@ func (l *lexer) SetInj(inj int) { l.inj = inj } +func (l *lexer) SetPrepare() { + l.prepare = true +} + +func (l *lexer) IsPrepare() bool { + return l.prepare +} + func (l *lexer) Root() bool { return l.root } @@ -1001,7 +1011,8 @@ func (l *lexer) str(lval *yySymType, pref string) int { s = strings.TrimSuffix(s, "'") + "\"" pref = "\"" } - v, err := strconv.Unquote(pref + s) + v := stringutil.RemoveUselessBackslash(pref+s) + v, err := strconv.Unquote(v) if err != nil { v = strings.TrimSuffix(s, pref) } diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 41c78a815e..7a62c1ee96 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -166,14 +166,14 @@ func (c *expressionConverter) subquery(v *ast.SubqueryExpr) { oldSubquery := &subquery.SubQuery{} switch x := v.Query.(type) { case *ast.SelectStmt: - oldSelect, err := convertSelect(x) + oldSelect, err := convertSelect(c, x) if err != nil { c.err = err return } oldSubquery.Stmt = oldSelect case *ast.UnionStmt: - oldUnion, err := convertUnion(x) + oldUnion, err := convertUnion(c, x) if err != nil { c.err = err return diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 036bfebefd..b58f2d240d 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -16,8 +16,10 @@ package optimizer import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/stmt" @@ -37,17 +39,16 @@ func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expr return oldAssign, nil } -func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { +func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { oldInsert := &stmts.InsertIntoStmt{ Priority: v.Priority, Text: v.Text(), } - tableName := v.Table.TableRefs.Left.(*ast.TableName) + tableName := v.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} for _, val := range v.Columns { oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) } - converter := newExpressionConverter() for _, row := range v.Lists { var oldRow []expression.Expression for _, val := range row { @@ -77,9 +78,9 @@ func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { var err error switch x := v.Select.(type) { case *ast.SelectStmt: - oldInsert.Sel, err = convertSelect(x) + oldInsert.Sel, err = convertSelect(converter, x) case *ast.UnionStmt: - oldInsert.Sel, err = convertUnion(x) + oldInsert.Sel, err = convertUnion(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -88,7 +89,7 @@ func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { return oldInsert, nil } -func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { +func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { oldDelete := &stmts.DeleteStmt{ BeforeFrom: v.BeforeFrom, Ignore: v.Ignore, @@ -97,7 +98,6 @@ func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { Quick: v.Quick, Text: v.Text(), } - converter := newExpressionConverter() oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) if err != nil { return nil, errors.Trace(err) @@ -131,14 +131,13 @@ func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { return oldDelete, nil } -func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { +func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { oldUpdate := &stmts.UpdateStmt{ Ignore: v.Ignore, MultipleTable: v.MultipleTable, LowPriority: v.LowPriority, Text: v.Text(), } - converter := newExpressionConverter() var err error oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) if err != nil { @@ -176,10 +175,7 @@ func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { return oldUpdate, nil } -func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { - converter := &expressionConverter{ - exprMap: map[ast.Node]expression.Expression{}, - } +func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { oldSelect := &stmts.SelectStmt{} oldSelect.Distinct = s.Distinct oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) @@ -187,9 +183,13 @@ func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { oldField := &field.Field{} oldField.AsName = val.AsName.O var err error - oldField.Expr, err = convertExpr(converter, val.Expr) - if err != nil { - return nil, errors.Trace(err) + if val.Expr != nil { + oldField.Expr, err = convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + } else if val.WildCard != nil { + oldField.Expr = &expression.Ident{CIStr: model.NewCIStr("*")} } oldSelect.Fields[i] = oldField } @@ -245,7 +245,7 @@ func convertSelect(s *ast.SelectStmt) (*stmts.SelectStmt, error) { return oldSelect, nil } -func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { +func convertUnion(converter *expressionConverter, u *ast.UnionStmt) (*stmts.UnionStmt, error) { oldUnion := &stmts.UnionStmt{} oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) oldUnion.Distincts = make([]bool, len(u.Selects)-1) @@ -255,12 +255,19 @@ func convertUnion(u *ast.UnionStmt) (*stmts.UnionStmt, error) { } } for i, val := range u.Selects { - oldSelect, err := convertSelect(val) + oldSelect, err := convertSelect(converter, val) if err != nil { return nil, errors.Trace(err) } oldUnion.Selects[i] = oldSelect } + if u.OrderBy != nil { + oldOrderBy, err := convertOrderBy(converter, u.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.OrderBy = oldOrderBy + } if u.Limit != nil { if u.Limit.Offset > 0 { oldUnion.Offset = &rsets.OffsetRset{Count: u.Limit.Offset} @@ -298,7 +305,7 @@ func convertJoin(converter *expressionConverter, join *ast.Join) (*rsets.JoinRse oldJoin.Left = oldLeft } - switch r := join.Left.(type) { + switch r := join.Right.(type) { case *ast.Join: oldRight, err := convertJoin(converter, r) if err != nil { @@ -331,13 +338,13 @@ func convertTableSource(converter *expressionConverter, ts *ast.TableSource) (*r case *ast.TableName: oldTs.Source = table.Ident{Schema: src.Schema, Name: src.Name} case *ast.SelectStmt: - oldSelect, err := convertSelect(src) + oldSelect, err := convertSelect(converter, src) if err != nil { return nil, errors.Trace(err) } oldTs.Source = oldSelect case *ast.UnionStmt: - oldUnion, err := convertUnion(src) + oldUnion, err := convertUnion(converter, src) if err != nil { return nil, errors.Trace(err) } @@ -384,7 +391,7 @@ func convertOrderBy(converter *expressionConverter, orderBy *ast.OrderByClause) return oldOrderBy, nil } -func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { +func convertCreateDatabase(converter *expressionConverter, v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { oldCreateDatabase := &stmts.CreateDatabaseStmt{ IfNotExists: v.IfNotExists, Name: v.Name, @@ -404,7 +411,7 @@ func convertCreateDatabase(v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt return oldCreateDatabase, nil } -func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { +func convertDropDatabase(converter *expressionConverter, v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { return &stmts.DropDatabaseStmt{ IfExists: v.IfExists, Name: v.Name, @@ -512,12 +519,12 @@ func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*cold return oldConstraint, nil } -func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { +func convertCreateTable(converter *expressionConverter, v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { oldCreateTable := &stmts.CreateTableStmt{ - Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + IfNotExists: v.IfNotExists, + Text: v.Text(), } - converter := newExpressionConverter() for _, val := range v.Cols { oldColDef, err := convertColumnDef(converter, val) if err != nil { @@ -551,7 +558,7 @@ func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) return oldCreateTable, nil } -func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { +func convertDropTable(converter *expressionConverter, v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { oldDropTable := &stmts.DropTableStmt{ IfExists: v.IfExists, Text: v.Text(), @@ -566,7 +573,7 @@ func convertDropTable(v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { return oldDropTable, nil } -func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { +func convertCreateIndex(converter *expressionConverter, v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { oldCreateIndex := &stmts.CreateIndexStmt{ IndexName: v.IndexName, Unique: v.Unique, @@ -587,7 +594,7 @@ func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) return oldCreateIndex, nil } -func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { +func convertDropIndex(converter *expressionConverter, v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { return &stmts.DropIndexStmt{ IfExists: v.IfExists, IndexName: v.IndexName, @@ -595,14 +602,110 @@ func convertDropIndex(v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { }, nil } -func convertAlterTable(v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { +func convertAlterTableSpec(converter *expressionConverter, v *ast.AlterTableSpec) (*ddl.AlterSpecification, error) { + oldAlterSpec := &ddl.AlterSpecification{ + Name: v.Name, + } + switch v.Tp { + case ast.AlterTableAddConstraint: + oldAlterSpec.Action = ddl.AlterAddConstr + case ast.AlterTableAddColumn: + oldAlterSpec.Action = ddl.AlterAddColumn + case ast.AlterTableDropColumn: + oldAlterSpec.Action = ddl.AlterDropColumn + case ast.AlterTableDropForeignKey: + oldAlterSpec.Action = ddl.AlterDropForeignKey + case ast.AlterTableDropIndex: + oldAlterSpec.Action = ddl.AlterDropIndex + case ast.AlterTableDropPrimaryKey: + oldAlterSpec.Action = ddl.AlterDropPrimaryKey + case ast.AlterTableOption: + oldAlterSpec.Action = ddl.AlterTableOpt + } + if v.Column != nil { + oldColDef, err := convertColumnDef(converter, v.Column) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Column = oldColDef + } + if v.Position != nil { + oldAlterSpec.Position = &ddl.ColumnPosition{} + switch v.Position.Tp { + case ast.ColumnPositionNone: + oldAlterSpec.Position.Type = ddl.ColumnPositionNone + case ast.ColumnPositionFirst: + oldAlterSpec.Position.Type = ddl.ColumnPositionFirst + case ast.ColumnPositionAfter: + oldAlterSpec.Position.Type = ddl.ColumnPositionAfter + } + if v.ColumnName != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.ColumnName) + } + } + if v.Constraint != nil { + oldConstraint, err := convertConstraint(converter, v.Constraint) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Constraint = oldConstraint + } + for _, val := range v.Options { + oldOpt := &coldef.TableOpt{ + StrValue: val.StrValue, + UintValue: val.UintValue, + } + switch val.Tp { + case ast.TableOptionNone: + oldOpt.Tp = coldef.TblOptNone + case ast.TableOptionEngine: + oldOpt.Tp = coldef.TblOptEngine + case ast.TableOptionCharset: + oldOpt.Tp = coldef.TblOptCharset + case ast.TableOptionCollate: + oldOpt.Tp = coldef.TblOptCollate + case ast.TableOptionAutoIncrement: + oldOpt.Tp = coldef.TblOptAutoIncrement + case ast.TableOptionComment: + oldOpt.Tp = coldef.TblOptComment + case ast.TableOptionAvgRowLength: + oldOpt.Tp = coldef.TblOptAvgRowLength + case ast.TableOptionCheckSum: + oldOpt.Tp = coldef.TblOptCheckSum + case ast.TableOptionCompression: + oldOpt.Tp = coldef.TblOptCompression + case ast.TableOptionConnection: + oldOpt.Tp = coldef.TblOptConnection + case ast.TableOptionPassword: + oldOpt.Tp = coldef.TblOptPassword + case ast.TableOptionKeyBlockSize: + oldOpt.Tp = coldef.TblOptKeyBlockSize + case ast.TableOptionMaxRows: + oldOpt.Tp = coldef.TblOptMaxRows + case ast.TableOptionMinRows: + oldOpt.Tp = coldef.TblOptMinRows + } + oldAlterSpec.TableOpts = append(oldAlterSpec.TableOpts, oldOpt) + } + return oldAlterSpec, nil +} + +func convertAlterTable(converter *expressionConverter, v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { oldAlterTable := &stmts.AlterTableStmt{ - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertAlterTableSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterTable.Specs = append(oldAlterTable.Specs, oldSpec) } return oldAlterTable, nil } -func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { +func convertTruncateTable(converter *expressionConverter, v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { return &stmts.TruncateTableStmt{ TableIdent: table.Ident{ Schema: v.Table.Schema, @@ -612,20 +715,20 @@ func convertTruncateTable(v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, e }, nil } -func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { +func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { oldExplain := &stmts.ExplainStmt{ Text: v.Text(), } var err error switch x := v.Stmt.(type) { case *ast.SelectStmt: - oldExplain.S, err = convertSelect(x) + oldExplain.S, err = convertSelect(converter, x) case *ast.UpdateStmt: - oldExplain.S, err = convertUpdate(x) + oldExplain.S, err = convertUpdate(converter, x) case *ast.DeleteStmt: - oldExplain.S, err = convertDelete(x) + oldExplain.S, err = convertDelete(converter, x) case *ast.InsertStmt: - oldExplain.S, err = convertInsert(x) + oldExplain.S, err = convertInsert(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -633,7 +736,7 @@ func convertExplain(v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { return oldExplain, nil } -func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { +func convertPrepare(converter *expressionConverter, v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { oldPrepare := &stmts.PreparedStmt{ InPrepare: true, Name: v.Name, @@ -651,7 +754,7 @@ func convertPrepare(v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { return oldPrepare, nil } -func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { +func convertDeallocate(converter *expressionConverter, v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { return &stmts.DeallocateStmt{ ID: v.ID, Name: v.Name, @@ -659,14 +762,13 @@ func convertDeallocate(v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { }, nil } -func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { +func convertExecute(converter *expressionConverter, v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { oldExec := &stmts.ExecuteStmt{ ID: v.ID, Name: v.Name, Text: v.Text(), } oldExec.UsingVars = make([]expression.Expression, len(v.UsingVars)) - converter := newExpressionConverter() for i, val := range v.UsingVars { oldVar, err := convertExpr(converter, val) if err != nil { @@ -677,7 +779,7 @@ func convertExecute(v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { return oldExec, nil } -func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { +func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowStmt, error) { oldShow := &stmts.ShowStmt{ DBName: v.DBName, Flag: v.Flag, @@ -735,25 +837,25 @@ func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { return oldShow, nil } -func convertBegin(v *ast.BeginStmt) (*stmts.BeginStmt, error) { +func convertBegin(converter *expressionConverter, v *ast.BeginStmt) (*stmts.BeginStmt, error) { return &stmts.BeginStmt{ Text: v.Text(), }, nil } -func convertCommit(v *ast.CommitStmt) (*stmts.CommitStmt, error) { +func convertCommit(converter *expressionConverter, v *ast.CommitStmt) (*stmts.CommitStmt, error) { return &stmts.CommitStmt{ Text: v.Text(), }, nil } -func convertRollback(v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { +func convertRollback(converter *expressionConverter, v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { return &stmts.RollbackStmt{ Text: v.Text(), }, nil } -func convertUse(v *ast.UseStmt) (*stmts.UseStmt, error) { +func convertUse(converter *expressionConverter, v *ast.UseStmt) (*stmts.UseStmt, error) { return &stmts.UseStmt{ DBName: v.DBName, Text: v.Text(), @@ -775,12 +877,11 @@ func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAs }, nil } -func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { +func convertSet(converter *expressionConverter, v *ast.SetStmt) (*stmts.SetStmt, error) { oldSet := &stmts.SetStmt{ Text: v.Text(), Variables: make([]*stmts.VariableAssignment, len(v.Variables)), } - converter := newExpressionConverter() for i, val := range v.Variables { oldAssign, err := convertVariableAssignment(converter, val) if err != nil { @@ -791,7 +892,7 @@ func convertSet(v *ast.SetStmt) (*stmts.SetStmt, error) { return oldSet, nil } -func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { +func convertSetCharset(converter *expressionConverter, v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { return &stmts.SetCharsetStmt{ Charset: v.Charset, Collate: v.Collate, @@ -799,7 +900,7 @@ func convertSetCharset(v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { }, nil } -func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { +func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { return &stmts.SetPwdStmt{ User: v.User, Password: v.Password, @@ -807,14 +908,13 @@ func convertSetPwd(v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { }, nil } -func convertDo(v *ast.DoStmt) (*stmts.DoStmt, error) { - exprConverter := newExpressionConverter() +func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, error) { oldDo := &stmts.DoStmt{ Text: v.Text(), Exprs: make([]expression.Expression, len(v.Exprs)), } for i, val := range v.Exprs { - oldExpr, err := convertExpr(exprConverter, val) + oldExpr, err := convertExpr(converter, val) if err != nil { return nil, errors.Trace(err) } diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index cec754d164..c4d3966cd8 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -16,79 +16,95 @@ package optimizer import ( "github.com/juju/errors" "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/stmt" ) +type Compiler struct { + paramMarkers []*expression.ParamMarker +} + // Compile compiles a ast.Node into a executable statement. -func Compile(node ast.Node) (stmt.Statement, error) { +func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { validator := &validator{} if _, ok := node.Accept(validator); !ok { return nil, errors.Trace(validator.err) } - binder := &InfoBinder{} - if _, ok := node.Accept(validator); !ok { - return nil, errors.Trace(binder.Err) - } + // binder := &InfoBinder{} + // if _, ok := node.Accept(validator); !ok { + // return nil, errors.Trace(binder.Err) + // } tpComputer := &typeComputer{} if _, ok := node.Accept(tpComputer); !ok { return nil, errors.Trace(tpComputer.err) } - + c := newExpressionConverter() + defer func() { + for _, v := range c.exprMap { + if x, ok := v.(*expression.ParamMarker); ok { + com.paramMarkers = append(com.paramMarkers, x) + } + } + }() switch v := node.(type) { case *ast.InsertStmt: - return convertInsert(v) + return convertInsert(c, v) case *ast.DeleteStmt: - return convertDelete(v) + return convertDelete(c, v) case *ast.UpdateStmt: - return convertUpdate(v) + return convertUpdate(c, v) case *ast.SelectStmt: - return convertSelect(v) + return convertSelect(c, v) case *ast.UnionStmt: - return convertUnion(v) + return convertUnion(c, v) case *ast.CreateDatabaseStmt: - return convertCreateDatabase(v) + return convertCreateDatabase(c, v) case *ast.DropDatabaseStmt: - return convertDropDatabase(v) + return convertDropDatabase(c, v) case *ast.CreateTableStmt: - return convertCreateTable(v) + return convertCreateTable(c, v) case *ast.DropTableStmt: - return convertDropTable(v) + return convertDropTable(c, v) case *ast.CreateIndexStmt: - return convertCreateIndex(v) + return convertCreateIndex(c, v) case *ast.DropIndexStmt: - return convertDropIndex(v) + return convertDropIndex(c, v) case *ast.AlterTableStmt: - return convertAlterTable(v) + return convertAlterTable(c, v) case *ast.TruncateTableStmt: - return convertTruncateTable(v) + return convertTruncateTable(c, v) case *ast.ExplainStmt: - return convertExplain(v) + return convertExplain(c, v) case *ast.PrepareStmt: - return convertPrepare(v) + return convertPrepare(c, v) case *ast.DeallocateStmt: - return convertDeallocate(v) + return convertDeallocate(c, v) case *ast.ExecuteStmt: - return convertExecute(v) + return convertExecute(c, v) case *ast.ShowStmt: - return convertShow(v) + return convertShow(c, v) case *ast.BeginStmt: - return convertBegin(v) + return convertBegin(c, v) case *ast.CommitStmt: - return convertCommit(v) + return convertCommit(c, v) case *ast.RollbackStmt: - return convertRollback(v) + return convertRollback(c, v) case *ast.UseStmt: - return convertUse(v) + return convertUse(c, v) case *ast.SetStmt: - return convertSet(v) + return convertSet(c, v) case *ast.SetCharsetStmt: - return convertSetCharset(v) + return convertSetCharset(c, v) case *ast.SetPwdStmt: - return convertSetPwd(v) + return convertSetPwd(c, v) case *ast.DoStmt: - return convertDo(v) + return convertDo(c, v) } return nil, nil } + +func (com *Compiler) ParamMarkers() []*expression.ParamMarker { + return com.paramMarkers +} diff --git a/tidb.go b/tidb.go index 55581af782..27fcf5167a 100644 --- a/tidb.go +++ b/tidb.go @@ -26,14 +26,13 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/optimizer" - "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/variable" @@ -132,15 +131,12 @@ func Compile(ctx context.Context, src string) ([]stmt.Statement, error) { rawStmt := l.Stmts() stmts := make([]stmt.Statement, len(rawStmt)) for i, v := range rawStmt { - if node, ok := v.(ast.Node); ok { - stm, err := optimizer.Compile(node) - if err != nil { - return nil, errors.Trace(err) - } - stmts[i] = stm - } else { - stmts[i] = v.(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, err := compiler.Compile(v) + if err != nil { + return nil, errors.Trace(err) } + stmts[i] = stm } return stmts, nil } @@ -162,7 +158,12 @@ func CompilePrepare(ctx context.Context, src string) (stmt.Statement, []*express return nil, nil, nil } sm := sms[0] - return sm.(stmt.Statement), l.ParamList, nil + compiler := &optimizer.Compiler{} + statement, err := compiler.Compile(sm) + if err != nil { + return nil, nil, errors.Trace(err) + } + return statement, compiler.ParamMarkers(), nil } func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) { diff --git a/tidb_test.go b/tidb_test.go index 0cb22ed701..8249ba21aa 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -640,6 +640,7 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) { // conflict mustExecSQL(c, se1, "begin") rs, err := exec(c, se1, "select * from t where c1=11 for update") + c.Assert(err, IsNil) _, err = rs.Rows(-1, 0) mustExecSQL(c, se2, "begin") From 8f534190160104c093ed891435ad7eea2688d999 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Fri, 30 Oct 2015 14:39:48 +0800 Subject: [PATCH 06/27] optimizer: fix bugs, pass tests. --- ast/ddl.go | 8 +- ast/expressions.go | 22 ++- ast/misc.go | 100 ++++++++++- ast/parser/parser.y | 230 ++++++++++++++++++++---- ast/parser/scanner.l | 6 + ddl/ddl_test.go | 7 +- optimizer/{optimizer.go => compiler.go} | 8 +- optimizer/convert_stmt.go | 103 ++++++++++- stmt/stmts/drop_test.go | 2 +- stmt/stmts/show_test.go | 2 +- 10 files changed, 426 insertions(+), 62 deletions(-) rename optimizer/{optimizer.go => compiler.go} (90%) diff --git a/ast/ddl.go b/ast/ddl.go index 37a5d99c88..87a5ffba70 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -498,7 +498,7 @@ type AlterTableSpec struct { Constraint *Constraint Options []*TableOption Column *ColumnDef - ColumnName *ColumnName + DropColumn *ColumnName Position *ColumnPosition } @@ -523,12 +523,12 @@ func (nod *AlterTableSpec) Accept(v Visitor) (Node, bool) { } nod.Column = node.(*ColumnDef) } - if nod.ColumnName != nil { - node, ok := nod.ColumnName.Accept(v) + if nod.DropColumn != nil { + node, ok := nod.DropColumn.Accept(v) if !ok { return nod, false } - nod.ColumnName = node.(*ColumnName) + nod.DropColumn = node.(*ColumnName) } if nod.Position != nil { node, ok := nod.Position.Accept(v) diff --git a/ast/expressions.go b/ast/expressions.go index e32b14752a..6c09874750 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -83,7 +83,7 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Type.Charset = "binary" ve.Type.Collate = "binary" default: - panic("illegal literal value type:" + fmt.Sprintf("T", value)) + panic(fmt.Sprintf("illegal literal value type:%T", value)) } return ve } @@ -581,16 +581,20 @@ func (nod *PatternLikeExpr) Accept(v Visitor) (Node, bool) { return v.Leave(newNod) } nod = newNod.(*PatternLikeExpr) - node, ok := nod.Expr.Accept(v) - if !ok { - return nod, false + if nod.Expr != nil { + node, ok := nod.Expr.Accept(v) + if !ok { + return nod, false + } + nod.Expr = node.(ExprNode) } - nod.Expr = node.(ExprNode) - node, ok = nod.Pattern.Accept(v) - if !ok { - return nod, false + if nod.Pattern != nil { + node, ok := nod.Pattern.Accept(v) + if !ok { + return nod, false + } + nod.Pattern = node.(ExprNode) } - nod.Pattern = node.(ExprNode) return v.Leave(nod) } diff --git a/ast/misc.go b/ast/misc.go index d269e27c84..05ff97184e 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -13,6 +13,8 @@ package ast +import "github.com/pingcap/tidb/mysql" + var ( _ StmtNode = &ExplainStmt{} _ StmtNode = &PrepareStmt{} @@ -26,7 +28,9 @@ var ( _ StmtNode = &SetStmt{} _ StmtNode = &SetCharsetStmt{} _ StmtNode = &SetPwdStmt{} + _ StmtNode = &CreateUserStmt{} _ StmtNode = &DoStmt{} + _ StmtNode = &GrantStmt{} _ Node = &VariableAssignment{} ) @@ -165,7 +169,7 @@ const ( // ShowStmt is a statement to provide information about databases, tables, columns and so on. // See: https://dev.mysql.com/doc/refman/5.7/en/show.html type ShowStmt struct { - stmtNode + dmlNode Tp ShowStmtType // Databases/Tables/Columns/.... DBName string @@ -379,10 +383,20 @@ type UserSpec struct { // CreateUserStmt creates user account. // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html type CreateUserStmt struct { + stmtNode + IfNotExists bool Specs []*UserSpec +} - Text string +// Accept implements Node Accept interface. +func (nod *CreateUserStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*CreateUserStmt) + return v.Leave(nod) } // DoStmt is the struct for DO statement. @@ -408,3 +422,85 @@ func (nod *DoStmt) Accept(v Visitor) (Node, bool) { } return v.Leave(nod) } + +// PrivElem is the privilege type and optional column list. +type PrivElem struct { + node + Priv mysql.PrivilegeType + Cols []*ColumnName +} + +// Accept implements Node Accept interface. +func (nod *PrivElem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*PrivElem) + for i, val := range nod.Cols { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Cols[i] = node.(*ColumnName) + } + return v.Leave(nod) +} + +// ObjectTypeType is the type for object type. +type ObjectTypeType int + +const ( + // ObjectTypeNone is for empty object type. + ObjectTypeNone ObjectTypeType = iota + // ObjectTypeTable means the following object is a table. + ObjectTypeTable +) + +// GrantLevelType is the type for grant level. +type GrantLevelType int + +const ( + // GrantLevelNone is the dummy const for default value. + GrantLevelNone GrantLevelType = iota + // GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server. + GrantLevelGlobal + // GrantLevelDB means the privileges apply to all objects in a given database. + GrantLevelDB + // GrantLevelTable means the privileges apply to all columns in a given table. + GrantLevelTable +) + +// GrantLevel is used for store the privilege scope. +type GrantLevel struct { + Level GrantLevelType + DBName string + TableName string +} + +// GrantStmt is the struct for GRANT statement. +type GrantStmt struct { + stmtNode + + Privs []*PrivElem + ObjectType ObjectTypeType + Level *GrantLevel + Users []*UserSpec +} + +// Accept implements Node Accept interface. +func (nod *GrantStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*GrantStmt) + for i, val := range nod.Privs { + node, ok := val.Accept(v) + if !ok { + return nod, false + } + nod.Privs[i] = node.(*PrivElem) + } + return v.Leave(nod) +} diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 643195111a..97745a62ad 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -141,6 +141,7 @@ import ( fulltext "FULLTEXT" ge ">=" global "GLOBAL" + grant "GRANT" group "GROUP" groupConcat "GROUP_CONCAT" having "HAVING" @@ -190,6 +191,7 @@ import ( nullIf "NULLIF" offset "OFFSET" on "ON" + option "OPTION" or "OR" order "ORDER" oror "||" @@ -229,6 +231,7 @@ import ( tableKwd "TABLE" tables "TABLES" then "THEN" + to "TO" trailing "TRAILING" transaction "TRANSACTION" trim "TRIM" @@ -403,6 +406,7 @@ import ( FunctionNameConflict "Built-in function call names which are conflict with keywords" FuncDatetimePrec "Function datetime precision" GlobalScope "The scope of variable" + GrantStmt "Grant statement" GroupByClause "GROUP BY clause" HashString "Hashed string" HavingClause "HAVING clause" @@ -430,6 +434,7 @@ import ( NotOpt "optional NOT" NowSym "CURRENT_TIMESTAMP/LOCALTIME/LOCALTIMESTAMP/NOW" NumLiteral "Num/Int/Float/Decimal Literal" + ObjectType "Grant statement object type" OnDuplicateKeyUpdate "ON DUPLICATE KEY UPDATE value list" Operand "operand" OptFull "Full or empty" @@ -448,6 +453,10 @@ import ( PrimaryExpression "primary expression" PrimaryFactor "primary expression factor" Priority "insert statement priority" + PrivElem "Privilege element" + PrivElemList "Privilege element list" + PrivLevel "Privilege scope" + PrivType "Privilege type" ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" RollbackStmt "ROLLBACK statement" @@ -491,7 +500,7 @@ import ( UnionOpt "Union Option(empty/ALL/DISTINCT)" UnionStmt "Union select state ment" UnionClauseList "Union select clause list" - UnionClausePList "Union (select) clause list" + UnionSelect "Union (select) item" UpdateStmt "UPDATE statement" Username "Username" UserSpec "Username and auth option" @@ -633,7 +642,7 @@ AlterTableSpec: { $$ = &ast.AlterTableSpec{ Tp: ast.AlterTableDropColumn, - ColumnName: $3.(*ast.ColumnName), + DropColumn: $3.(*ast.ColumnName), } } | "DROP" "PRIMARY" "KEY" @@ -1253,7 +1262,7 @@ ExplainStmt: { $$ = &ast.ExplainStmt{ Stmt: &ast.ShowStmt{ - Tp: ast.ShowTables, + Tp: ast.ShowColumns, Table: $2.(*ast.TableName), }, } @@ -2842,32 +2851,36 @@ SelectLockOpt: // See: https://dev.mysql.com/doc/refman/5.7/en/union.html UnionStmt: - UnionClauseList - { - $$ = $1.(*ast.UnionStmt) - } -| UnionClausePList OrderByOptional SelectStmtLimit + UnionClauseList "UNION" UnionOpt SelectStmt { union := $1.(*ast.UnionStmt) - if $2 != nil { - union.OrderBy = $2.(*ast.OrderByClause) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union + } +| UnionClauseList "UNION" UnionOpt '(' SelectStmt ')' OrderByOptional SelectStmtLimit + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) + if $7 != nil { + union.OrderBy = $7.(*ast.OrderByClause) } - if $3 != nil { - union.Limit = $3.(*ast.Limit) + if $8 != nil { + union.Limit = $8.(*ast.Limit) } $$ = union } UnionClauseList: - SelectStmt "UNION" UnionOpt SelectStmt + UnionSelect { - selects := []*ast.SelectStmt{$1.(*ast.SelectStmt), $4.(*ast.SelectStmt)} + selects := []*ast.SelectStmt{$1.(*ast.SelectStmt)} $$ = &ast.UnionStmt{ - Distinct: $3.(bool), Selects: selects, } } -| UnionClauseList "UNION" UnionOpt SelectStmt +| UnionClauseList "UNION" UnionOpt UnionSelect { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) @@ -2875,21 +2888,11 @@ UnionClauseList: $$ = union } -UnionClausePList: - '(' SelectStmt ')' "UNION" UnionOpt '(' SelectStmt ')' +UnionSelect: + SelectStmt +| '(' SelectStmt ')' { - selects := []*ast.SelectStmt{$2.(*ast.SelectStmt), $7.(*ast.SelectStmt)} - $$ = &ast.UnionStmt{ - Distinct: $5.(bool), - Selects: selects, - } - } -| UnionClausePList "UNION" UnionOpt '(' SelectStmt ')' - { - union := $1.(*ast.UnionStmt) - union.Distinct = union.Distinct || $3.(bool) - union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) - $$ = union + $$ = $2 } UnionOpt: @@ -3066,7 +3069,11 @@ ShowStmt: Full: $2.(bool), } if $5 != nil { - stmt.Where = $5.(ast.ExprNode) + if x, ok := $5.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { + stmt.Where = $5.(ast.ExprNode) + } } $$ = stmt } @@ -3089,7 +3096,9 @@ ShowStmt: Tp: ast.ShowVariables, GlobalScope: $2.(bool), } - if $4 != nil { + if x, ok := $4.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { stmt.Where = $4.(ast.ExprNode) } $$ = stmt @@ -3099,7 +3108,9 @@ ShowStmt: stmt := &ast.ShowStmt{ Tp: ast.ShowCollation, } - if $3 != nil { + if x, ok := $3.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { stmt.Where = $3.(ast.ExprNode) } $$ = stmt @@ -3187,6 +3198,7 @@ Statement: | DropDatabaseStmt | DropIndexStmt | DropTableStmt +| GrantStmt | InsertIntoStmt | PreparedStmt | RollbackStmt @@ -3895,10 +3907,13 @@ CreateUserStmt: UserSpec: Username AuthOption { - $$ = &ast.UserSpec{ + userSpec := &ast.UserSpec{ User: $1.(string), - AuthOpt: $2.(*ast.AuthOption), } + if $2 != nil { + userSpec.AuthOpt = $2.(*ast.AuthOption) + } + $$ = userSpec } UserSpecList: @@ -3912,7 +3927,9 @@ UserSpecList: } AuthOption: - {} + { + $$ = nil + } | "IDENTIFIED" "BY" AuthString { $$ = &ast.AuthOption { @@ -3922,11 +3939,150 @@ AuthOption: } | "IDENTIFIED" "BY" "PASSWORD" HashString { - $$ = &ast.AuthOption { + $$ = &ast.AuthOption{ HashString: $4.(string), } } HashString: stringLit + +/************************************************************************************* + * Grant statement + * See: https://dev.mysql.com/doc/refman/5.7/en/grant.html + *************************************************************************************/ +GrantStmt: + "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecList + { + $$ = &ast.GrantStmt{ + Privs: $2.([]*ast.PrivElem), + ObjectType: $4.(ast.ObjectTypeType), + Level: $5.(*ast.GrantLevel), + Users: $7.([]*ast.UserSpec), + } + } + +PrivElem: + PrivType + { + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + } + } +| PrivType '(' ColumnNameList ')' + { + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + Cols: $3.([]*ast.ColumnName), + } + } + +PrivElemList: + PrivElem + { + $$ = []*ast.PrivElem{$1.(*ast.PrivElem)} + } +| PrivElemList ',' PrivElem + { + $$ = append($1.([]*ast.PrivElem), $3.(*ast.PrivElem)) + } + +PrivType: + "ALL" + { + $$ = mysql.AllPriv + } +| "ALTER" + { + $$ = mysql.AlterPriv + } +| "CREATE" + { + $$ = mysql.CreatePriv + } +| "CREATE" "USER" + { + $$ = mysql.CreateUserPriv + } +| "DELETE" + { + $$ = mysql.DeletePriv + } +| "DROP" + { + $$ = mysql.DropPriv + } +| "EXECUTE" + { + $$ = mysql.ExecutePriv + } +| "INDEX" + { + $$ = mysql.IndexPriv + } +| "INSERT" + { + $$ = mysql.InsertPriv + } +| "SELECT" + { + $$ = mysql.SelectPriv + } +| "SHOW" "DATABASES" + { + $$ = mysql.ShowDBPriv + } +| "UPDATE" + { + $$ = mysql.UpdatePriv + } +| "GRANT" "OPTION" + { + $$ = mysql.GrantPriv + } + +ObjectType: + { + $$ = ast.ObjectTypeNone + } +| "TABLE" + { + $$ = ast.ObjectTypeTable + } + +PrivLevel: + '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + } + } +| '*' '.' '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelGlobal, + } + } +| Identifier '.' '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + DBName: $1.(string), + } + } +| Identifier '.' Identifier + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, + DBName: $1.(string), + TableName: $3.(string), + } + } +| Identifier + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, + TableName: $1.(string), + } + } %% diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index cdbe765347..21a8680588 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -316,6 +316,7 @@ from {f}{r}{o}{m} full {f}{u}{l}{l} fulltext {f}{u}{l}{l}{t}{e}{x}{t} global {g}{l}{o}{b}{a}{l} +grant {g}{r}{a}{n}{t} group {g}{r}{o}{u}{p} group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} having {h}{a}{v}{i}{n}{g} @@ -356,6 +357,7 @@ national {n}{a}{t}{i}{o}{n}{a}{l} not {n}{o}{t} offset {o}{f}{f}{s}{e}{t} on {o}{n} +option {o}{p}{t}{i}{o}{n} or {o}{r} order {o}{r}{d}{e}{r} outer {o}{u}{t}{e}{r} @@ -389,6 +391,7 @@ sysdate {s}{y}{s}{d}{a}{t}{e} table {t}{a}{b}{l}{e} tables {t}{a}{b}{l}{e}{s} then {t}{h}{e}{n} +to {t}{o} trailing {t}{r}{a}{i}{l}{i}{n}{g} transaction {t}{r}{a}{n}{s}{a}{c}{t}{i}{o}{n} trim {t}{r}{i}{m} @@ -683,6 +686,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {full} lval.item = string(l.val) return full {fulltext} return fulltext +{grant} return grant {group} return group {group_concat} lval.item = string(l.val) return groupConcat @@ -757,6 +761,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {offset} lval.item = string(l.val) return offset {on} return on +{option} return option {order} return order {or} return or {outer} return outer @@ -819,6 +824,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {tables} lval.item = string(l.val) return tables {then} return then +{to} return to {trailing} return trailing {transaction} lval.item = string(l.val) return transaction diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 864567a29f..a1805a0bcd 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -19,12 +19,13 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb" + "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/optimizer" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/stmt/stmts" @@ -198,5 +199,7 @@ func statement(sql string) stmt.Statement { log.Debug("Compile", sql) lexer := parser.NewLexer(sql) parser.YYParse(lexer) - return lexer.Stmts()[0].(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, _ := compiler.Compile(lexer.Stmts()[0]) + return stm } diff --git a/optimizer/optimizer.go b/optimizer/compiler.go similarity index 90% rename from optimizer/optimizer.go rename to optimizer/compiler.go index c4d3966cd8..bff4623520 100644 --- a/optimizer/optimizer.go +++ b/optimizer/compiler.go @@ -20,11 +20,12 @@ import ( "github.com/pingcap/tidb/stmt" ) +// Compiler compiles ast.Node into an executable statement. type Compiler struct { paramMarkers []*expression.ParamMarker } -// Compile compiles a ast.Node into a executable statement. +// Compile compiles a ast.Node into an executable statement. func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { validator := &validator{} if _, ok := node.Accept(validator); !ok { @@ -99,12 +100,17 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { return convertSetCharset(c, v) case *ast.SetPwdStmt: return convertSetPwd(c, v) + case *ast.CreateUserStmt: + return convertCreateUser(c, v) case *ast.DoStmt: return convertDo(c, v) + case *ast.GrantStmt: + return convertGrant(c, v) } return nil, nil } +// ParamMarkers returns parameter markers for prepared statement. func (com *Compiler) ParamMarkers() []*expression.ParamMarker { return com.paramMarkers } diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index b58f2d240d..3be497443d 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -176,8 +176,10 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up } func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { - oldSelect := &stmts.SelectStmt{} - oldSelect.Distinct = s.Distinct + oldSelect := &stmts.SelectStmt{ + Distinct: s.Distinct, + Text: s.Text(), + } oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) for i, val := range s.Fields.Fields { oldField := &field.Field{} @@ -246,7 +248,9 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se } func convertUnion(converter *expressionConverter, u *ast.UnionStmt) (*stmts.UnionStmt, error) { - oldUnion := &stmts.UnionStmt{} + oldUnion := &stmts.UnionStmt{ + Text: u.Text(), + } oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) oldUnion.Distincts = make([]bool, len(u.Selects)-1) if u.Distinct { @@ -639,10 +643,13 @@ func convertAlterTableSpec(converter *expressionConverter, v *ast.AlterTableSpec case ast.ColumnPositionAfter: oldAlterSpec.Position.Type = ddl.ColumnPositionAfter } - if v.ColumnName != nil { - oldAlterSpec.Position.RelativeColumn = joinColumnName(v.ColumnName) + if v.Position.RelativeColumn != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.Position.RelativeColumn) } } + if v.DropColumn != nil { + oldAlterSpec.Name = joinColumnName(v.DropColumn) + } if v.Constraint != nil { oldConstraint, err := convertConstraint(converter, v.Constraint) if err != nil { @@ -729,6 +736,8 @@ func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts. oldExplain.S, err = convertDelete(converter, x) case *ast.InsertStmt: oldExplain.S, err = convertInsert(converter, x) + case *ast.ShowStmt: + oldExplain.S, err = convertShow(converter, x) } if err != nil { return nil, errors.Trace(err) @@ -908,6 +917,36 @@ func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.Se }, nil } +func convertUserSpec(converter *expressionConverter, v *ast.UserSpec) (*coldef.UserSpecification, error) { + oldSpec := &coldef.UserSpecification{ + User: v.User, + } + if v.AuthOpt != nil { + oldAuthOpt := &coldef.AuthOption{ + AuthString: v.AuthOpt.AuthString, + ByAuthString: v.AuthOpt.ByAuthString, + HashString: v.AuthOpt.HashString, + } + oldSpec.AuthOpt = oldAuthOpt + } + return oldSpec, nil +} + +func convertCreateUser(converter *expressionConverter, v *ast.CreateUserStmt) (*stmts.CreateUserStmt, error) { + oldCreateUser := &stmts.CreateUserStmt{ + IfNotExists: v.IfNotExists, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateUser.Specs = append(oldCreateUser.Specs, oldSpec) + } + return oldCreateUser, nil +} + func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, error) { oldDo := &stmts.DoStmt{ Text: v.Text(), @@ -922,3 +961,57 @@ func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, er } return oldDo, nil } + +func convertPrivElem(converter *expressionConverter, v *ast.PrivElem) (*coldef.PrivElem, error) { + oldPrim := &coldef.PrivElem{ + Priv: v.Priv, + } + for _, val := range v.Cols { + oldPrim.Cols = append(oldPrim.Cols, joinColumnName(val)) + } + return oldPrim, nil +} + +func convertGrant(converter *expressionConverter, v *ast.GrantStmt) (*stmts.GrantStmt, error) { + oldGrant := &stmts.GrantStmt{ + Text: v.Text(), + } + for _, val := range v.Privs { + oldPrim, err := convertPrivElem(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Privs = append(oldGrant.Privs, oldPrim) + } + switch v.ObjectType { + case ast.ObjectTypeNone: + oldGrant.ObjectType = coldef.ObjectTypeNone + case ast.ObjectTypeTable: + oldGrant.ObjectType = coldef.ObjectTypeTable + } + if v.Level != nil { + oldGrantLevel := &coldef.GrantLevel{ + DBName: v.Level.DBName, + TableName: v.Level.TableName, + } + switch v.Level.Level { + case ast.GrantLevelDB: + oldGrantLevel.Level = coldef.GrantLevelDB + case ast.GrantLevelGlobal: + oldGrantLevel.Level = coldef.GrantLevelGlobal + case ast.GrantLevelNone: + oldGrantLevel.Level = coldef.GrantLevelNone + case ast.GrantLevelTable: + oldGrantLevel.Level = coldef.GrantLevelTable + } + oldGrant.Level = oldGrantLevel + } + for _, val := range v.Users { + oldUserSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Users = append(oldGrant.Users, oldUserSpec) + } + return oldGrant, nil +} diff --git a/stmt/stmts/drop_test.go b/stmt/stmts/drop_test.go index 5d0a34ec0d..4a1de810c2 100644 --- a/stmt/stmts/drop_test.go +++ b/stmt/stmts/drop_test.go @@ -62,7 +62,7 @@ func (s *testStmtSuite) TestDropTable(c *C) { } func (s *testStmtSuite) TestDropIndex(c *C) { - testSQL := "drop index if exists drop_index;" + testSQL := "drop index if exists drop_index on t;" stmtList, err := tidb.Compile(s.ctx, testSQL) c.Assert(err, IsNil) diff --git a/stmt/stmts/show_test.go b/stmt/stmts/show_test.go index 5dc0c57828..e67813a46d 100644 --- a/stmt/stmts/show_test.go +++ b/stmt/stmts/show_test.go @@ -51,5 +51,5 @@ func (s *testStmtSuite) TestShow(c *C) { c.Assert(stmtList, HasLen, 1) testStmt, ok = stmtList[0].(*stmts.ShowStmt) c.Assert(ok, IsTrue) - c.Assert(testStmt.Pattern, NotNil) + c.Assert(testStmt.Pattern, NotNil, Commentf("S: %s", testStmt.Text)) } From 05b7eefa03a0b10315a2a1205c77439270fd0b58 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Fri, 30 Oct 2015 20:51:36 +0800 Subject: [PATCH 07/27] ast: update after rebase. --- ast/dml.go | 1 + ast/functions.go | 10 +++-- ast/misc.go | 2 + ast/parser/parser.y | 80 ++++++++++++++++++++++++++++++++++----- ast/parser/scanner.l | 19 ++++++++++ optimizer/convert_stmt.go | 55 ++++++++++++++++----------- 6 files changed, 130 insertions(+), 37 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index b19e60bfc6..afd836ef66 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -566,6 +566,7 @@ const ( type InsertStmt struct { dmlNode + Replace bool Table *TableRefsClause Columns []*ColumnName Lists [][]ExprNode diff --git a/ast/functions.go b/ast/functions.go index 97b9cd1e3c..78ab444236 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -200,11 +200,13 @@ func (nod *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { return nod, false } nod.Pos = node.(ExprNode) - node, ok = nod.Len.Accept(v) - if !ok { - return nod, false + if nod.Len != nil { + node, ok = nod.Len.Accept(v) + if !ok { + return nod, false + } + nod.Len = node.(ExprNode) } - nod.Len = node.(ExprNode) return v.Leave(nod) } diff --git a/ast/misc.go b/ast/misc.go index 05ff97184e..b0f9e845fc 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -164,6 +164,7 @@ const ( ShowVariables ShowCollation ShowCreateTable + ShowGrants ) // ShowStmt is a statement to provide information about databases, tables, columns and so on. @@ -177,6 +178,7 @@ type ShowStmt struct { Column *ColumnName // Used for `desc table column`. Flag int // Some flag parsed from sql, such as FULL. Full bool + User string // Used for show grants. // Used by show variables GlobalScope bool diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 97745a62ad..9c02a8e4c5 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -142,6 +142,7 @@ import ( ge ">=" global "GLOBAL" grant "GRANT" + grants "GRANTS" group "GROUP" groupConcat "GROUP_CONCAT" having "HAVING" @@ -206,6 +207,7 @@ import ( references "REFERENCES" regexp "REGEXP" repeat "REPEAT" + replace "REPLACE" right "RIGHT" rlike "RLIKE" rollback "ROLLBACK" @@ -418,7 +420,7 @@ import ( IndexName "index name" IndexType "index type" InsertIntoStmt "INSERT INTO statement" - InsertRest "Rest part of INSERT INTO statement" + InsertValues "Rest part of INSERT/REPLACE INTO statement" IntoOpt "INTO or EmptyString" JoinTable "join table" JoinType "join type" @@ -459,6 +461,8 @@ import ( PrivType "Privilege type" ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" + ReplaceIntoStmt "REPLACE INTO statement" + ReplacePriority "replace statement priority" RollbackStmt "ROLLBACK statement" SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," SelectStmt "SELECT statement" @@ -1224,14 +1228,14 @@ DropIndexStmt: } DropTableStmt: - "DROP" "TABLE" TableNameList + "DROP" TableOrTables TableNameList { $$ = &ast.DropTableStmt{Tables: $3.([]*ast.TableName)} if yylex.(*lexer).root { break } } -| "DROP" "TABLE" "IF" "EXISTS" TableNameList +| "DROP" TableOrTables "IF" "EXISTS" TableNameList { $$ = &ast.DropTableStmt{IfExists: true, Tables: $5.([]*ast.TableName)} if yylex.(*lexer).root { @@ -1239,6 +1243,10 @@ DropTableStmt: } } +TableOrTables: + "TABLE" +| "TABLES" + EqOpt: { } @@ -1619,7 +1627,7 @@ UnReservedKeyword: | "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" | "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION" | "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS" -| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" +| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" NotKeywordToken: "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" @@ -1633,7 +1641,7 @@ NotKeywordToken: * TODO: support PARTITION **********************************************************************************/ InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableName InsertRest OnDuplicateKeyUpdate + "INSERT" Priority IgnoreOptional IntoOpt TableName InsertValues OnDuplicateKeyUpdate { x := $6.(*ast.InsertStmt) x.Priority = $2.(int) @@ -1656,7 +1664,7 @@ IntoOpt: { } -InsertRest: +InsertValues: '(' ColumnNameListOpt ')' ValueSym ExpressionListList { $$ = &ast.InsertStmt{ @@ -1741,7 +1749,39 @@ OnDuplicateKeyUpdate: $$ = $5 } -/***********************************Insert Statments END************************************/ +/***********************************Insert Statements END************************************/ + +/************************************************************************************ + * Replace Statements + * See: https://dev.mysql.com/doc/refman/5.7/en/replace.html + * + * TODO: support PARTITION + **********************************************************************************/ +ReplaceIntoStmt: + "REPLACE" ReplacePriority IntoOpt TableName InsertValues + { + x := $5.(*ast.InsertStmt) + x.Replace = true + x.Priority = $2.(int) + ts := &ast.TableSource{Source: $4.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} + $$ = x + } + +ReplacePriority: + { + $$ = ast.NoPriority + } +| "LOW_PRIORITY" + { + $$ = ast.LowPriority + } +| "DELAYED" + { + $$ = ast.DelayedPriority + } + +/***********************************Replace Statments END************************************/ Literal: "false" @@ -1984,7 +2024,7 @@ FunctionCallKeyword: } | "YEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } FunctionCallNonKeyword: @@ -2049,7 +2089,7 @@ FunctionCallNonKeyword: } | "IFNULL" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} } | "LENGTH" '(' Expression ')' { @@ -2107,6 +2147,11 @@ FunctionCallNonKeyword: } $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} } +| "REPLACE" '(' Expression ',' Expression ',' Expression ')' + { + args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} + $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + } | "SECOND" '(' Expression ')' { $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} @@ -3118,10 +3163,23 @@ ShowStmt: | "SHOW" "CREATE" "TABLE" TableName { $$ = &ast.ShowStmt{ - Tp: ast.ShowCreateTable, + Tp: ast.ShowCreateTable, Table: $4.(*ast.TableName), } } +| "SHOW" "GRANTS" + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &ast.ShowStmt{Tp: ast.ShowGrants} + } +| "SHOW" "GRANTS" "FOR" Username + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &ast.ShowStmt{ + Tp: ast.ShowGrants, + User: $4.(string), + } + } ShowLikeOrWhereOpt: { @@ -3202,6 +3260,7 @@ Statement: | InsertIntoStmt | PreparedStmt | RollbackStmt +| ReplaceIntoStmt | SelectStmt | UnionStmt | SetStmt @@ -3221,6 +3280,7 @@ ExplainableStmt: | DeleteFromStmt | UpdateStmt | InsertIntoStmt +| ReplaceIntoStmt StatementList: Statement diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index 21a8680588..048817d1b9 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -54,6 +54,10 @@ type lexer struct { // record token's offset of the input tokenEndOffset int tokenStartOffset int + + // Charset information + charset string + collation string } @@ -104,6 +108,15 @@ func (l *lexer) SetRoot(root bool) { l.root = root } +func (l *lexer) SetCharsetInfo(charset, collation string) { + l.charset = charset + l.collation = collation +} + +func (l *lexer) GetCharsetInfo() (string, string) { + return l.charset, l.collation +} + func (l *lexer) unget(b byte) { l.ungetBuf = append(l.ungetBuf, b) l.i-- @@ -317,6 +330,7 @@ full {f}{u}{l}{l} fulltext {f}{u}{l}{l}{t}{e}{x}{t} global {g}{l}{o}{b}{a}{l} grant {g}{r}{a}{n}{t} +grants {g}{r}{a}{n}{t}{s} group {g}{r}{o}{u}{p} group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} having {h}{a}{v}{i}{n}{g} @@ -370,6 +384,7 @@ rand {r}{a}{n}{d} repeat {r}{e}{p}{e}{a}{t} references {r}{e}{f}{e}{r}{e}{n}{c}{e}{s} regexp {r}{e}{g}{e}{x}{p} +replace {r}{e}{p}{l}{a}{c}{e} right {r}{i}{g}{h}{t} rlike {r}{l}{i}{k}{e} rollback {r}{o}{l}{l}{b}{a}{c}{k} @@ -687,6 +702,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} return full {fulltext} return fulltext {grant} return grant +{grants} lval.item = string(l.val) + return grants {group} return group {group_concat} lval.item = string(l.val) return groupConcat @@ -795,6 +812,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {repeat} lval.item = string(l.val) return repeat {regexp} return regexp +{replace} lval.item = string(l.val) + return replace {references} return references {rlike} return rlike diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 3be497443d..ff026f6591 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -39,15 +39,13 @@ func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expr return oldAssign, nil } -func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { - oldInsert := &stmts.InsertIntoStmt{ - Priority: v.Priority, - Text: v.Text(), - } +func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (stmt.Statement, error) { + insertValues := stmts.InsertValues{} + insertValues.Priority = v.Priority tableName := v.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) - oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} + insertValues.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} for _, val := range v.Columns { - oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) + insertValues.ColNames = append(insertValues.ColNames, joinColumnName(val)) } for _, row := range v.Lists { var oldRow []expression.Expression @@ -58,33 +56,44 @@ func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (*stmts.In } oldRow = append(oldRow, oldExpr) } - oldInsert.Lists = append(oldInsert.Lists, oldRow) + insertValues.Lists = append(insertValues.Lists, oldRow) } for _, assign := range v.Setlist { oldAssign, err := convertAssignment(converter, assign) if err != nil { return nil, errors.Trace(err) } - oldInsert.Setlist = append(oldInsert.Setlist, oldAssign) + insertValues.Setlist = append(insertValues.Setlist, oldAssign) + } + + if v.Select != nil { + var err error + switch x := v.Select.(type) { + case *ast.SelectStmt: + insertValues.Sel, err = convertSelect(converter, x) + case *ast.UnionStmt: + insertValues.Sel, err = convertUnion(converter, x) + } + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Replace { + return &stmts.ReplaceIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), + }, nil + } + oldInsert := &stmts.InsertIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), } for _, onDup := range v.OnDuplicate { oldOnDup, err := convertAssignment(converter, onDup) if err != nil { return nil, errors.Trace(err) } - oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, *oldOnDup) - } - if v.Select != nil { - var err error - switch x := v.Select.(type) { - case *ast.SelectStmt: - oldInsert.Sel, err = convertSelect(converter, x) - case *ast.UnionStmt: - oldInsert.Sel, err = convertUnion(converter, x) - } - if err != nil { - return nil, errors.Trace(err) - } + oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, oldOnDup) } return oldInsert, nil } @@ -155,7 +164,7 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up if err != nil { return nil, errors.Trace(err) } - oldUpdate.List = append(oldUpdate.List, *oldAssign) + oldUpdate.List = append(oldUpdate.List, oldAssign) } if v.Order != nil { orderRset := &rsets.OrderByRset{} From 92db5158c489d60608bd057ad6b065f79fc90e68 Mon Sep 17 00:00:00 2001 From: ngaut Date: Sat, 31 Oct 2015 13:18:38 +0800 Subject: [PATCH 08/27] *: Tiny refactor --- ast/cloner.go | 2 +- ast/functions.go | 4 ++-- optimizer/convert_expr.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ast/cloner.go b/ast/cloner.go index bf7b3c961d..be5d7e441e 100644 --- a/ast/cloner.go +++ b/ast/cloner.go @@ -15,7 +15,7 @@ package ast import "fmt" -// Cloner is a ast visitor that clones a node. +// Cloner is an ast visitor that clones a node. type Cloner struct { } diff --git a/ast/functions.go b/ast/functions.go index 78ab444236..56bf663eeb 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -35,7 +35,7 @@ var ( type FuncCallExpr struct { funcNode // F is the function name. - F string + FnName string // Args is the function args. Args []ExprNode } @@ -59,7 +59,7 @@ func (nod *FuncCallExpr) Accept(v Visitor) (Node, bool) { // IsStatic implements the ExprNode IsStatic interface. func (nod *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(nod.F)] + v := builtin.Funcs[strings.ToLower(nod.FnName)] if v.F == nil || !v.IsStatic { return false } diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 7a62c1ee96..ac57e91e77 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -318,7 +318,7 @@ func (c *expressionConverter) variable(v *ast.VariableExpr) { func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) { oldCall := &expression.Call{ - F: v.F, + F: v.FnName, } oldCall.Args = make([]expression.Expression, len(v.Args)) for i, val := range v.Args { From 1b9fbda6b6192de1a4427b1fd0ae6f200f74616d Mon Sep 17 00:00:00 2001 From: ngaut Date: Sat, 31 Oct 2015 13:22:56 +0800 Subject: [PATCH 09/27] ast/parser: Fix parser.y --- ast/parser/parser.y | 70 ++++++++++++++++++++++----------------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 9c02a8e4c5..4716d418ff 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1938,16 +1938,16 @@ FunctionNameConflict: FunctionCallConflict: FunctionNameConflict '(' ExpressionListOpt ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURRENT_USER" { // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_DATE" { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } DistinctOpt: @@ -2007,11 +2007,11 @@ FunctionCallKeyword: } | "DATE" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "USER" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues { @@ -2020,21 +2020,21 @@ FunctionCallKeyword: } | "WEEK" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "YEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } FunctionCallNonKeyword: "COALESCE" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURDATE" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_TIMESTAMP" FuncDatetimePrec { @@ -2042,35 +2042,35 @@ FunctionCallNonKeyword: if $2 != nil { args = append(args, $2.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "ABS" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "CONCAT" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CONCAT_WS" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "DAY" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFWEEK" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFMONTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFYEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { @@ -2081,19 +2081,19 @@ FunctionCallNonKeyword: } | "FOUND_ROWS" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "HOUR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "IFNULL" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "LENGTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "LOCATE" '(' Expression ',' Expression ')' { @@ -2112,19 +2112,19 @@ FunctionCallNonKeyword: } | "LOWER" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MICROSECOND" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MINUTE" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MONTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "NOW" '(' ExpressionOpt ')' { @@ -2132,11 +2132,11 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "NULLIF" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "RAND" '(' ExpressionOpt ')' { @@ -2145,16 +2145,16 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "REPLACE" '(' Expression ',' Expression ',' Expression ')' { args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "SECOND" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "SUBSTRING" '(' Expression ',' Expression ')' { @@ -2200,7 +2200,7 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "TRIM" '(' Expression ')' { @@ -2232,19 +2232,19 @@ FunctionCallNonKeyword: } | "UPPER" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKDAY" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKOFYEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "YEARWEEK" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } TrimDirection: From 01dac873baa102c4a8ae9a99b7253d03f9044c90 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Sat, 31 Oct 2015 22:37:27 +0800 Subject: [PATCH 10/27] *: Address ci problem --- ast/expressions.go | 4 +++- ast/parser/parser.y | 18 ++++++++++++++++++ ast/parser/scanner.l | 19 ++++++++++++++++++- 3 files changed, 39 insertions(+), 2 deletions(-) diff --git a/ast/expressions.go b/ast/expressions.go index 6c09874750..7c82ed2d2a 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -56,7 +56,7 @@ type ValueExpr struct { // NewValueExpr creates a ValueExpr with value, and sets default field type. func NewValueExpr(value interface{}) *ValueExpr { ve := &ValueExpr{} - ve.Data = value + ve.Data = types.RawData(value) // TODO: make it more precise. switch value.(type) { case nil: @@ -82,6 +82,8 @@ func NewValueExpr(value interface{}) *ValueExpr { ve.Type = types.NewFieldType(mysql.TypeVarchar) ve.Type.Charset = "binary" ve.Type.Collate = "binary" + case *types.DataItem: + ve.Type = value.(*types.DataItem).Type default: panic(fmt.Sprintf("illegal literal value type:%T", value)) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 9c02a8e4c5..127a3b965d 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -239,6 +239,7 @@ import ( trim "TRIM" trueKwd "true" truncate "TRUNCATE" + underscoreCS "UNDERSCORE_CHARSET" unknown "UNKNOWN" union "UNION" unique "UNIQUE" @@ -1796,6 +1797,23 @@ Literal: | floatLit | intLit | stringLit +| "UNDERSCORE_CHARSET" stringLit + { + // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html + tp := types.NewFieldType(mysql.TypeString) + tp.Charset = $1.(string) + co, err := charset.GetDefaultCollation(tp.Charset) + if err != nil { + l := yylex.(*lexer) + l.errf("Get collation error for charset: %s", tp.Charset) + return 1 + } + tp.Collate = co + $$ = &types.DataItem{ + Type: tp, + Data: $2.(string), + } + } | hexLit | bitLit diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index 048817d1b9..adaa662651 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -27,6 +27,7 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/stringutil" ) @@ -1010,7 +1011,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} return integerType {ident} lval.item = string(l.val) - return identifier + return l.handleIdent(lval) . return c0 @@ -1101,3 +1102,19 @@ func (l *lexer) bit(lval *yySymType) int { lval.item = b return bitLit } + +func (l *lexer) handleIdent(lval *yySymType) int { + s := lval.item.(string) + // A character string literal may have an optional character set introducer and COLLATE clause: + // [_charset_name]'string' [COLLATE collation_name] + // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html + if !strings.HasPrefix(s, "_") { + return identifier + } + cs, _, err := charset.GetCharsetInfo(s[1:]) + if err != nil { + return identifier + } + lval.item = cs + return underscoreCS +} From f3016ff56c1371cf8e295b4246a9371c68af2004 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Sat, 31 Oct 2015 23:07:31 +0800 Subject: [PATCH 11/27] optimizer: fix parameter marker out of range --- optimizer/compiler.go | 6 +----- optimizer/convert_expr.go | 9 ++++++--- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/optimizer/compiler.go b/optimizer/compiler.go index bff4623520..1fc875a206 100644 --- a/optimizer/compiler.go +++ b/optimizer/compiler.go @@ -43,11 +43,7 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { } c := newExpressionConverter() defer func() { - for _, v := range c.exprMap { - if x, ok := v.(*expression.ParamMarker); ok { - com.paramMarkers = append(com.paramMarkers, x) - } - } + com.paramMarkers = c.paramMarkers }() switch v := node.(type) { case *ast.InsertStmt: diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index ac57e91e77..52dfbe8859 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -33,8 +33,9 @@ func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression. // expressionConverter converts ast expression to // old expression for transition state. type expressionConverter struct { - exprMap map[ast.Node]expression.Expression - err error + exprMap map[ast.Node]expression.Expression + paramMarkers []*expression.ParamMarker + err error } func newExpressionConverter() *expressionConverter { @@ -265,7 +266,9 @@ func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { } func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { - c.exprMap[v] = &expression.ParamMarker{} + marker := &expression.ParamMarker{} + c.exprMap[v] = marker + c.paramMarkers = append(c.paramMarkers, marker) } func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { From eed9647417dd0148c7b648faa1675bda8cf06939 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Mon, 2 Nov 2015 11:34:13 +0800 Subject: [PATCH 12/27] ast, optimizer: add Offset for ParmMarker to make sure the order of param markers. --- ast/expressions.go | 1 + ast/parser/parser.y | 4 +++- optimizer/compiler.go | 30 +++++++++++++++++++++++++----- optimizer/convert_expr.go | 7 +++---- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/ast/expressions.go b/ast/expressions.go index 7c82ed2d2a..fe0d311b64 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -609,6 +609,7 @@ func (nod *PatternLikeExpr) IsStatic() bool { // Used in parsing prepare statement. type ParamMarkerExpr struct { exprNode + Offset int } // Accept implements Node Accept interface. diff --git a/ast/parser/parser.y b/ast/parser/parser.y index a32c488fdf..33a4f9f1be 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1844,7 +1844,9 @@ Operand: } | "PLACEHOLDER" { - $$ = &ast.ParamMarkerExpr{} + $$ = &ast.ParamMarkerExpr{ + Offset: yyS[yypt].offset, + } } | "ROW" '(' Expression ',' ExpressionList ')' { diff --git a/optimizer/compiler.go b/optimizer/compiler.go index 1fc875a206..a9e42cd8b0 100644 --- a/optimizer/compiler.go +++ b/optimizer/compiler.go @@ -14,6 +14,8 @@ package optimizer import ( + "sort" + "github.com/juju/errors" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/expression" @@ -22,7 +24,7 @@ import ( // Compiler compiles ast.Node into an executable statement. type Compiler struct { - paramMarkers []*expression.ParamMarker + converter *expressionConverter } // Compile compiles a ast.Node into an executable statement. @@ -42,9 +44,7 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { return nil, errors.Trace(tpComputer.err) } c := newExpressionConverter() - defer func() { - com.paramMarkers = c.paramMarkers - }() + com.converter = c switch v := node.(type) { case *ast.InsertStmt: return convertInsert(c, v) @@ -106,7 +106,27 @@ func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { return nil, nil } +type paramMarkers []*ast.ParamMarkerExpr + +func (p paramMarkers) Len() int { + return len(p) +} + +func (p paramMarkers) Less(i, j int) bool { + return p[i].Offset < p[j].Offset +} + +func (p paramMarkers) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + // ParamMarkers returns parameter markers for prepared statement. func (com *Compiler) ParamMarkers() []*expression.ParamMarker { - return com.paramMarkers + c := com.converter + sort.Sort(c.paramMarkers) + oldMarkers := make([]*expression.ParamMarker, len(c.paramMarkers)) + for i, val := range c.paramMarkers { + oldMarkers[i] = c.exprMap[val].(*expression.ParamMarker) + } + return oldMarkers } diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 52dfbe8859..e8aa1bce84 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -34,7 +34,7 @@ func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression. // old expression for transition state. type expressionConverter struct { exprMap map[ast.Node]expression.Expression - paramMarkers []*expression.ParamMarker + paramMarkers paramMarkers err error } @@ -266,9 +266,8 @@ func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { } func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { - marker := &expression.ParamMarker{} - c.exprMap[v] = marker - c.paramMarkers = append(c.paramMarkers, marker) + c.paramMarkers = append(c.paramMarkers, v) + c.exprMap[v] = &expression.ParamMarker{} } func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { From ab2ec90d93a3468291c21d057fd9076414764eb3 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Mon, 2 Nov 2015 13:34:04 +0800 Subject: [PATCH 13/27] ast/parser, optimizer: fix duplicate parameter marker, fix offset. --- ast/parser/parser.y | 2 +- optimizer/convert_expr.go | 6 ++++-- optimizer/convert_stmt.go | 3 --- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 33a4f9f1be..90de8fad3c 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -2833,7 +2833,7 @@ SelectStmtLimit: } | "LIMIT" LengthNum "OFFSET" LengthNum { - $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} + $$ = &ast.Limit{Offset: $4.(uint64), Count: $2.(uint64)} } SelectStmtDistinct: diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index e8aa1bce84..d43f93f016 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -266,8 +266,10 @@ func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { } func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { - c.paramMarkers = append(c.paramMarkers, v) - c.exprMap[v] = &expression.ParamMarker{} + if c.exprMap[v] == nil { + c.exprMap[v] = &expression.ParamMarker{} + c.paramMarkers = append(c.paramMarkers, v) + } } func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index ff026f6591..535dfcddea 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -762,7 +762,6 @@ func convertPrepare(converter *expressionConverter, v *ast.PrepareStmt) (*stmts. Text: v.Text(), } if v.SQLVar != nil { - converter := newExpressionConverter() oldSQLVar, err := convertExpr(converter, v.SQLVar) if err != nil { return nil, errors.Trace(err) @@ -815,7 +814,6 @@ func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowSt oldShow.ColumnName = joinColumnName(v.Column) } if v.Where != nil { - converter := newExpressionConverter() oldWhere, err := convertExpr(converter, v.Where) if err != nil { return nil, errors.Trace(err) @@ -823,7 +821,6 @@ func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowSt oldShow.Where = oldWhere } if v.Pattern != nil { - converter := newExpressionConverter() oldPattern, err := convertExpr(converter, v.Pattern) if err != nil { return nil, errors.Trace(err) From 067df3ab3438ef0aaa42097743b808884751ebd5 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Mon, 2 Nov 2015 18:37:49 +0800 Subject: [PATCH 14/27] optimizer: fix wildcard conversion. --- ast/expressions.go | 2 +- ast/functions.go | 3 +++ ast/parser/parser.y | 5 +++-- optimizer/convert_stmt.go | 9 ++++++++- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/ast/expressions.go b/ast/expressions.go index fe0d311b64..cddb7f45d5 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -66,7 +66,7 @@ func NewValueExpr(value interface{}) *ValueExpr { case uint64: ve.Type = types.NewFieldType(mysql.TypeLonglong) ve.Type.Flag |= mysql.UnsignedFlag - case string: + case string, UnquoteString: ve.Type = types.NewFieldType(mysql.TypeVarchar) ve.Type.Charset = mysql.DefaultCharset ve.Type.Collate = mysql.DefaultCollationName diff --git a/ast/functions.go b/ast/functions.go index 56bf663eeb..e3323ef3d1 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -31,6 +31,9 @@ var ( _ FuncNode = &AggregateFuncExpr{} ) +// UnquoteString is not quoted when printed. +type UnquoteString string + // FuncCallExpr is for function expression. type FuncCallExpr struct { funcNode diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 90de8fad3c..a35ea0d086 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -2036,7 +2036,7 @@ FunctionCallKeyword: | "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues { // TODO: support qualified identifier for column_name - $$ = &ast.ColumnNameExpr{Name: $3.(*ast.ColumnName)} + $$ = &ast.ValuesExpr{Column: $3.(*ast.ColumnName)} } | "WEEK" '(' ExpressionList ')' { @@ -2292,7 +2292,7 @@ FunctionCallAgg: } | "COUNT" '(' DistinctOpt '*' ')' { - args := []ast.ExprNode{ast.NewValueExpr("*")} + args := []ast.ExprNode{ast.NewValueExpr(ast.UnquoteString("*"))} $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' @@ -2753,6 +2753,7 @@ TableFactor: TableAsNameOpt: { $$ = model.CIStr{} + yyS[yypt].offset = yylex.(*lexer).i } | TableAsName { diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 535dfcddea..528e360754 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -200,7 +200,14 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se return nil, errors.Trace(err) } } else if val.WildCard != nil { - oldField.Expr = &expression.Ident{CIStr: model.NewCIStr("*")} + str := "*" + if val.WildCard.Table.O != "" { + str = val.WildCard.Table.O + ".*" + if val.WildCard.Schema.O != "" { + str = val.WildCard.Schema.O + "." + str + } + } + oldField.Expr = &expression.Ident{CIStr: model.NewCIStr(str)} } oldSelect.Fields[i] = oldField } From ba9dfc3551e2b4ddf6afb6cbc0c8b9421082bb2e Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 10:00:17 +0800 Subject: [PATCH 15/27] ast/parser: add text to select field. --- ast/ast.go | 2 ++ ast/dml.go | 2 ++ ast/parser/parser.y | 64 +++++++++++++++++++++++++++++++++++---- ast/parser/scanner.l | 10 ++++++ optimizer/convert_stmt.go | 3 ++ 5 files changed, 75 insertions(+), 6 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index fc505c8362..6b699d796e 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -97,10 +97,12 @@ type ResultSetNode interface { // Visitor visits a Node. type Visitor interface { // VisitEnter is called before children nodes is visited. + // The returned node should be the same type as the input node n. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. + // The returned node should be the same type as the input node n. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) } diff --git a/ast/dml.go b/ast/dml.go index afd836ef66..b45fc6374d 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -203,6 +203,8 @@ func (nod *WildCardField) Accept(v Visitor) (Node, bool) { type SelectField struct { node + // Offset is used to get original text. + Offset int // If WildCard is not nil, Expr will be nil. WildCard *WildCardField // If Expr is not nil, WildCard will be nil. diff --git a/ast/parser/parser.y b/ast/parser/parser.y index a35ea0d086..1d6b05e314 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1516,7 +1516,15 @@ Field: } | Expression FieldAsNameOpt { - $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: model.NewCIStr($2.(string))} + expr := $1.(ast.ExprNode) + asName := $2.(string) + if asName != "" { + // Set expr original text. + offset := yyS[yypt-1].offset + end := yyS[yypt].offset-1 + expr.SetText(yylex.(*lexer).src[offset:end]) + } + $$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)} } FieldAsNameOpt: @@ -1550,11 +1558,22 @@ FieldAsName: FieldList: Field { - $$ = []*ast.SelectField{$1.(*ast.SelectField)} + field := $1.(*ast.SelectField) + field.Offset = yyS[yypt].offset + $$ = []*ast.SelectField{field} } | FieldList ',' Field { - $$ = append($1.([]*ast.SelectField), $3.(*ast.SelectField)) + + fl := $1.([]*ast.SelectField) + last := fl[len(fl)-1] + if last.Expr != nil && last.AsName.O == "" { + lastEnd := yyS[yypt-1].offset-1 // Comma offset. + last.SetText(yylex.(*lexer).src[last.Offset:lastEnd]) + } + newField := $3.(*ast.SelectField) + newField.Offset = yyS[yypt].offset + $$ = append(fl, newField) } GroupByClause: @@ -2628,6 +2647,22 @@ SelectStmt: Fields: $3.(*ast.FieldList), LockTp: $5.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + src := yylex.(*lexer).src + var lastEnd int + if $4 != nil { + lastEnd = yyS[yypt-1].offset-1 + } else if $5 != ast.SelectLockNone { + lastEnd = yyS[yypt].offset-1 + } else { + lastEnd = len(src) + if src[lastEnd-1] == ';' { + lastEnd-- + } + } + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } if $4 != nil { st.Limit = $4.(*ast.Limit) } @@ -2640,6 +2675,11 @@ SelectStmt: Fields: $3.(*ast.FieldList), LockTp: $7.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-3].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } if $5 != nil { st.Where = $5.(ast.ExprNode) } @@ -2659,6 +2699,12 @@ SelectStmt: LockTp: $11.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-7].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } + if $6 != nil { st.Where = $6.(ast.ExprNode) } @@ -2739,6 +2785,8 @@ TableFactor: } | '(' SelectStmt ')' TableAsName { + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1) $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } | '(' UnionStmt ')' TableAsName @@ -2753,7 +2801,6 @@ TableFactor: TableAsNameOpt: { $$ = model.CIStr{} - yyS[yypt].offset = yylex.(*lexer).i } | TableAsName { @@ -2886,6 +2933,7 @@ SubSelect: '(' SelectStmt ')' { s := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) @@ -2928,7 +2976,9 @@ UnionStmt: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) - union.Selects = append(union.Selects, $5.(*ast.SelectStmt)) + st := $5.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, st) if $7 != nil { union.OrderBy = $7.(*ast.OrderByClause) } @@ -2958,7 +3008,9 @@ UnionSelect: SelectStmt | '(' SelectStmt ')' { - $$ = $2 + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1) + $$ = st } UnionOpt: diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index adaa662651..b58a8f6f3e 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -118,6 +118,16 @@ func (l *lexer) GetCharsetInfo() (string, string) { return l.charset, l.collation } +// The select statement is not at the end of the whole statement, if the last +// field text was set from its offset to the end of the src string, update +// the last field text. +func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 { + lastField.SetText(l.src[lastField.Offset:lastEnd]) + } +} + func (l *lexer) unget(b byte) { l.ungetBuf = append(l.ungetBuf, b) l.i-- diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 528e360754..321cdada18 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -199,6 +199,9 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se if err != nil { return nil, errors.Trace(err) } + if oldField.AsName == "" { + oldField.AsName = val.Text() + } } else if val.WildCard != nil { str := "*" if val.WildCard.Table.O != "" { From 53a11459c2d408f3d9ea3a7785f8f9ed87eae4a8 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 10:28:33 +0800 Subject: [PATCH 16/27] optimizer: only set old field AsName when it is not identifier --- optimizer/convert_stmt.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 321cdada18..875eb62aea 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -199,7 +199,7 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se if err != nil { return nil, errors.Trace(err) } - if oldField.AsName == "" { + if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" { oldField.AsName = val.Text() } } else if val.WildCard != nil { From 5ca94b542e55b182c4c836f5c38671eb24e71dc4 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 11:18:34 +0800 Subject: [PATCH 17/27] optimizer: add comments for set field text. --- optimizer/convert_stmt.go | 1 + 1 file changed, 1 insertion(+) diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 875eb62aea..98459ff694 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -199,6 +199,7 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se if err != nil { return nil, errors.Trace(err) } + // TODO: handle parenthesesed column name expression, which should not set AsName. if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" { oldField.AsName = val.Text() } From a74b373a5271df29bd3da9ab8c24eb5e36ea8548 Mon Sep 17 00:00:00 2001 From: Shen Li Date: Tue, 3 Nov 2015 11:24:47 +0800 Subject: [PATCH 18/27] parser: Add charset info for string literal --- ast/parser/parser.y | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 127a3b965d..da9e510966 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1797,6 +1797,15 @@ Literal: | floatLit | intLit | stringLit + { + tp := types.NewFieldType(mysql.TypeString) + l := yylex.(*lexer) + tp.Charset, tp.Collate = l.GetCharsetInfo() + $$ = &types.DataItem{ + Type: tp, + Data: $1.(string), + } + } | "UNDERSCORE_CHARSET" stringLit { // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html From 6b8a39523f85b96fbf324fd9efd8a952cdd3d845 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 3 Nov 2015 13:53:46 +0800 Subject: [PATCH 19/27] ast/parser: add tests and fix bugs. --- ast/dml.go | 39 +-- ast/functions.go | 47 ++++ ast/parser/parser.y | 79 +++++- ast/parser/parser_test.go | 568 ++++++++++++++++++++++++++++++++++++++ ast/parser/scanner.l | 8 + optimizer/convert_expr.go | 17 ++ optimizer/convert_stmt.go | 26 +- 7 files changed, 734 insertions(+), 50 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index b45fc6374d..f8cbd01c11 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -642,7 +642,7 @@ type DeleteStmt struct { // Only used in multiple table delete statement. Tables []*TableName Where ExprNode - Order []*ByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -680,20 +680,20 @@ func (nod *DeleteStmt) Accept(v Visitor) (Node, bool) { } nod.Where = node.(ExprNode) } - - for i, val := range nod.Order { - node, ok = val.Accept(v) + if nod.Order != nil { + node, ok = nod.Order.Accept(v) if !ok { return nod, false } - nod.Order[i] = node.(*ByItem) + nod.Order = node.(*OrderByClause) } - - node, ok = nod.Limit.Accept(v) - if !ok { - return nod, false + if nod.Limit != nil { + node, ok = nod.Limit.Accept(v) + if !ok { + return nod, false + } + nod.Limit = node.(*Limit) } - nod.Limit = node.(*Limit) return v.Leave(nod) } @@ -705,7 +705,7 @@ type UpdateStmt struct { TableRefs *TableRefsClause List []*Assignment Where ExprNode - Order []*ByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -738,19 +738,20 @@ func (nod *UpdateStmt) Accept(v Visitor) (Node, bool) { } nod.Where = node.(ExprNode) } - - for i, val := range nod.Order { - node, ok = val.Accept(v) + if nod.Order != nil { + node, ok = nod.Order.Accept(v) if !ok { return nod, false } - nod.Order[i] = node.(*ByItem) + nod.Order = node.(*OrderByClause) } - node, ok = nod.Limit.Accept(v) - if !ok { - return nod, false + if nod.Limit != nil { + node, ok = nod.Limit.Accept(v) + if !ok { + return nod, false + } + nod.Limit = node.(*Limit) } - nod.Limit = node.(*Limit) return v.Leave(nod) } diff --git a/ast/functions.go b/ast/functions.go index e3323ef3d1..8d688f9eee 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -28,6 +28,7 @@ var ( _ FuncNode = &FuncSubstringExpr{} _ FuncNode = &FuncLocateExpr{} _ FuncNode = &FuncTrimExpr{} + _ FuncNode = &FuncDateArithExpr{} _ FuncNode = &AggregateFuncExpr{} ) @@ -337,6 +338,52 @@ func (nod *FuncTrimExpr) IsStatic() bool { return nod.Str.IsStatic() && nod.RemStr.IsStatic() } +// DateArithType is type for DateArith option. +type DateArithType byte + +const ( + // DateAdd is to run date_add function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add + DateAdd DateArithType = iota + 1 + // DateSub is to run date_sub function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub + DateSub +) + +// FuncDateArithExpr is the struct for date arithmetic functions. +type FuncDateArithExpr struct { + funcNode + + Op DateArithType + Unit string + Date ExprNode + Interval ExprNode +} + +// Accept implements Node Accept interface. +func (nod *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(nod) + if skipChildren { + return v.Leave(newNod) + } + nod = newNod.(*FuncDateArithExpr) + if nod.Date != nil { + node, ok := nod.Date.Accept(v) + if !ok { + return nod, false + } + nod.Date = node.(ExprNode) + } + if nod.Interval != nil { + node, ok := nod.Date.Accept(v) + if !ok { + return nod, false + } + nod.Date = node.(ExprNode) + } + return v.Leave(nod) +} + // AggregateFuncExpr represents aggregate function expression. type AggregateFuncExpr struct { funcNode diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 1d6b05e314..4175b2c1f8 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -104,6 +104,8 @@ import ( currentUser "CURRENT_USER" database "DATABASE" databases "DATABASES" + dateAdd "DATE_ADD" + dateSub "DATE_SUB" day "DAY" dayofmonth "DAYOFMONTH" dayofweek "DAYOFWEEK" @@ -156,6 +158,7 @@ import ( index "INDEX" inner "INNER" insert "INSERT" + interval "INTERVAL" into "INTO" is "IS" join "JOIN" @@ -839,7 +842,7 @@ ColumnOption: { // See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html // The CHECK clause is parsed but ignored by all storage engines. - $$ = nil + $$ = &ast.ColumnOption{} } ColumnOptionList: @@ -1159,7 +1162,7 @@ DeleteFromStmt: x.Where = $7.(ast.ExprNode) } if $8 != nil { - x.Order = $8.([]*ast.ByItem) + x.Order = $8.(*ast.OrderByClause) } if $9 != nil { x.Limit = $9.(*ast.Limit) @@ -1650,7 +1653,7 @@ UnReservedKeyword: | "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" + "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" | "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" | "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" @@ -1689,12 +1692,17 @@ InsertValues: { $$ = &ast.InsertStmt{ Columns: $2.([]*ast.ColumnName), - Lists: $5.([][]ast.ExprNode)} + Lists: $5.([][]ast.ExprNode), + } } | '(' ColumnNameListOpt ')' SelectStmt { $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.SelectStmt)} } +| '(' ColumnNameListOpt ')' UnionStmt + { + $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.UnionStmt)} + } | ValueSym ExpressionListList %prec insertValues { $$ = &ast.InsertStmt{Lists: $2.([][]ast.ExprNode)} @@ -2111,6 +2119,24 @@ FunctionCallNonKeyword: { $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } +| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateAdd, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), + } + } +| "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateSub, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), + } + } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { $$ = &ast.FuncExtractExpr{ @@ -2969,6 +2995,8 @@ UnionStmt: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) $$ = union } @@ -2976,6 +3004,8 @@ UnionStmt: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1) st := $5.(*ast.SelectStmt) yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) union.Selects = append(union.Selects, st) @@ -3000,6 +3030,8 @@ UnionClauseList: { union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) $$ = union } @@ -3216,7 +3248,7 @@ ShowStmt: } if x, ok := $4.(*ast.PatternLikeExpr); ok { stmt.Pattern = x - } else { + } else if $4 != nil { stmt.Where = $4.(ast.ExprNode) } $$ = stmt @@ -3228,7 +3260,7 @@ ShowStmt: } if x, ok := $3.(*ast.PatternLikeExpr); ok { stmt.Pattern = x - } else { + } else if $3 != nil { stmt.Where = $3.(ast.ExprNode) } $$ = stmt @@ -3970,7 +4002,34 @@ StringName: * See: https://dev.mysql.com/doc/refman/5.7/en/update.html ***********************************************************************************/ UpdateStmt: - "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause + "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause + { + var refs *ast.Join + if x, ok := $4.(*ast.Join); ok { + refs = x + } else { + refs = &ast.Join{Left: $4.(ast.ResultSetNode)} + } + st := &ast.UpdateStmt{ + LowPriority: $2.(bool), + TableRefs: &ast.TableRefsClause{TableRefs: refs}, + List: $6.([]*ast.Assignment), + } + if $7 != nil { + st.Where = $7.(ast.ExprNode) + } + if $8 != nil { + st.Order = $8.(*ast.OrderByClause) + } + if $9 != nil { + st.Limit = $9.(*ast.Limit) + } + $$ = st + if yylex.(*lexer).root { + break + } + } +| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional { st := &ast.UpdateStmt{ LowPriority: $2.(bool), @@ -3980,12 +4039,6 @@ UpdateStmt: if $7 != nil { st.Where = $7.(ast.ExprNode) } - if $8 != nil { - st.Order = $8.([]*ast.ByItem) - } - if $9 != nil { - st.Limit = $9.(*ast.Limit) - } $$ = st if yylex.(*lexer).root { break diff --git a/ast/parser/parser_test.go b/ast/parser/parser_test.go index 1e2798c28c..3e8d7bb278 100644 --- a/ast/parser/parser_test.go +++ b/ast/parser/parser_test.go @@ -77,3 +77,571 @@ func (s *testParserSuite) TestSimple(c *C) { c.Assert(ok, IsTrue) } } + +type testCase struct { + src string + ok bool +} + +func (s *testParserSuite) RunTest(c *C, table []testCase) { + for _, t := range table { + l := NewLexer(t.src) + ok := yyParse(l) == 0 + c.Assert(ok, Equals, t.ok, Commentf("source %v %v", t.src, l.errs)) + switch ok { + case true: + c.Assert(l.errs, HasLen, 0, Commentf("src: %s", t.src)) + case false: + c.Assert(len(l.errs), Not(Equals), 0, Commentf("src: %s", t.src)) + } + } +} +func (s *testParserSuite) TestDMLStmt(c *C) { + table := []testCase{ + {"", true}, + {";", true}, + {"INSERT INTO foo VALUES (1234)", true}, + {"INSERT INTO foo VALUES (1234, 5678)", true}, + // 15 + {"INSERT INTO foo VALUES (1 || 2)", true}, + {"INSERT INTO foo VALUES (1 | 2)", true}, + {"INSERT INTO foo VALUES (false || true)", true}, + {"INSERT INTO foo VALUES (bar(5678))", false}, + // 20 + {"INSERT INTO foo VALUES ()", true}, + {"SELECT * FROM t", true}, + {"SELECT * FROM t AS u", true}, + // 25 + {"SELECT * FROM t, v", true}, + {"SELECT * FROM t AS u, v", true}, + {"SELECT * FROM t, v AS w", true}, + {"SELECT * FROM t AS u, v AS w", true}, + {"SELECT * FROM foo, bar, foo", true}, + // 30 + {"SELECT DISTINCTS * FROM t", false}, + {"SELECT DISTINCT * FROM t", true}, + {"INSERT INTO foo (a) VALUES (42)", true}, + {"INSERT INTO foo (a,) VALUES (42,)", false}, + // 35 + {"INSERT INTO foo (a,b) VALUES (42,314)", true}, + {"INSERT INTO foo (a,b,) VALUES (42,314)", false}, + {"INSERT INTO foo (a,b,) VALUES (42,314,)", false}, + {"INSERT INTO foo () VALUES ()", true}, + {"INSERT INTO foo VALUE ()", true}, + + {"REPLACE INTO foo VALUES (1 || 2)", true}, + {"REPLACE INTO foo VALUES (1 | 2)", true}, + {"REPLACE INTO foo VALUES (false || true)", true}, + {"REPLACE INTO foo VALUES (bar(5678))", false}, + {"REPLACE INTO foo VALUES ()", true}, + {"REPLACE INTO foo (a,b) VALUES (42,314)", true}, + {"REPLACE INTO foo (a,b,) VALUES (42,314)", false}, + {"REPLACE INTO foo (a,b,) VALUES (42,314,)", false}, + {"REPLACE INTO foo () VALUES ()", true}, + {"REPLACE INTO foo VALUE ()", true}, + // 40 + {`SELECT stuff.id + FROM stuff + WHERE stuff.value >= ALL (SELECT stuff.value + FROM stuff)`, true}, + {"BEGIN", true}, + {"START TRANSACTION", true}, + // 45 + {"COMMIT", true}, + {"ROLLBACK", true}, + {` + BEGIN; + INSERT INTO foo VALUES (42, 3.14); + INSERT INTO foo VALUES (-1, 2.78); + COMMIT;`, true}, + {` // A + BEGIN; + INSERT INTO tmp SELECT * from bar; + SELECT * from tmp; + + // B + ROLLBACK;`, true}, + + // set + // user defined + {"SET @a = 1", true}, + // session system variables + {"SET SESSION autocommit = 1", true}, + {"SET @@session.autocommit = 1", true}, + {"SET LOCAL autocommit = 1", true}, + {"SET @@local.autocommit = 1", true}, + {"SET @@autocommit = 1", true}, + {"SET autocommit = 1", true}, + // global system variables + {"SET GLOBAL autocommit = 1", true}, + {"SET @@global.autocommit = 1", true}, + // SET CHARACTER SET + {"SET CHARACTER SET utf8mb4;", true}, + {"SET CHARACTER SET 'utf8mb4';", true}, + // Set password + {"SET PASSWORD = 'password';", true}, + {"SET PASSWORD FOR 'root'@'localhost' = 'password';", true}, + + // qualified select + {"SELECT a.b.c FROM t", true}, + {"SELECT a.b.*.c FROM t", false}, + {"SELECT a.b.* FROM t", true}, + {"SELECT a FROM t", true}, + {"SELECT a.b.c.d FROM t", false}, + + // Do statement + {"DO 1", true}, + {"DO 1 from t", false}, + + // Select for update + {"SELECT * from t for update", true}, + {"SELECT * from t lock in share mode", true}, + + // For alter table + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED", true}, + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED FIRST", true}, + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED AFTER b", true}, + + // from join + {"SELECT * from t1, t2, t3", true}, + {"select * from t1 join t2 left join t3 on t2.id = t3.id", true}, + {"select * from t1 right join t2 on t1.id = t2.id left join t3 on t3.id = t2.id", true}, + {"select * from t1 right join t2 on t1.id = t2.id left join t3", false}, + + // For show full columns + {"show columns in t;", true}, + {"show full columns in t;", true}, + + // For set names + {"set names utf8", true}, + {"set names utf8 collate utf8_unicode_ci", true}, + + // For show character set + {"show character set;", true}, + // For on duplicate key update + {"INSERT INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, + {"INSERT IGNORE INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, + + // For SHOW statement + {"SHOW VARIABLES LIKE 'character_set_results'", true}, + {"SHOW GLOBAL VARIABLES LIKE 'character_set_results'", true}, + {"SHOW SESSION VARIABLES LIKE 'character_set_results'", true}, + {"SHOW VARIABLES", true}, + {"SHOW GLOBAL VARIABLES", true}, + {"SHOW GLOBAL VARIABLES WHERE Variable_name = 'autocommit'", true}, + {`SHOW FULL TABLES FROM icar_qa LIKE play_evolutions`, true}, + {`SHOW FULL TABLES WHERE Table_Type != 'VIEW'`, true}, + {`SHOW GRANTS`, true}, + {`SHOW GRANTS FOR 'test'@'localhost'`, true}, + + // For default value + {"CREATE TABLE sbtest (id INTEGER UNSIGNED NOT NULL AUTO_INCREMENT, k integer UNSIGNED DEFAULT '0' NOT NULL, c char(120) DEFAULT '' NOT NULL, pad char(60) DEFAULT '' NOT NULL, PRIMARY KEY (id) )", true}, + + // For delete statement + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, + + // For update statement + {"UPDATE t SET id = id + 1 ORDER BY id DESC;", true}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id;", true}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id LIMIT 10;", false}, + {"UPDATE user T0 LEFT OUTER JOIN user_profile T1 ON T1.id = T0.profile_id SET T0.profile_id = 1 WHERE T0.profile_id IN (1);", true}, + + // For select with where clause + {"SELECT * FROM t WHERE 1 = 1", true}, + + // For show collation + {"show collation", true}, + {"show collation like 'utf8%'", true}, + {"show collation where Charset = 'utf8' and Collation = 'utf8_bin'", true}, + + // For dual + {"select 1 from dual", true}, + {"select 1 from dual limit 1", true}, + {"select 1 where exists (select 2)", false}, + {"select 1 from dual where not exists (select 2)", true}, + + // For show create table + {"show create table test.t", true}, + {"show create table t", true}, + + // For https://github.com/pingcap/tidb/issues/320 + {`(select 1);`, true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestExpression(c *C) { + table := []testCase{ + // Sign expression + {"SELECT ++1", true}, + {"SELECT -*1", false}, + {"SELECT -+1", true}, + {"SELECT -1", true}, + {"SELECT --1", true}, + + // For string literal + {`select '''a''', """a"""`, true}, + {`select ''a''`, false}, + {`select ""a""`, false}, + {`select '''a''';`, true}, + {`select '\'a\'';`, true}, + {`select "\"a\"";`, true}, + {`select """a""";`, true}, + {`select _utf8"string";`, true}, + // For comparison + {"select 1 <=> 0, 1 <=> null, 1 = null", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestBuiltin(c *C) { + table := []testCase{ + // For buildin functions + {"SELECT DAYOFMONTH('2007-02-03');", true}, + {"SELECT RAND();", true}, + {"SELECT RAND(1);", true}, + + {"SELECT SUBSTRING('Quadratically',5);", true}, + {"SELECT SUBSTRING('Quadratically',5, 3);", true}, + {"SELECT SUBSTRING('Quadratically' FROM 5);", true}, + {"SELECT SUBSTRING('Quadratically' FROM 5 FOR 3);", true}, + + {"SELECT CONVERT('111', SIGNED);", true}, + + {"SELECT DATABASE();", true}, + {"SELECT USER();", true}, + {"SELECT CURRENT_USER();", true}, + {"SELECT CURRENT_USER;", true}, + + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', 2);", true}, + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', -2);", true}, + + {`SELECT LOWER("A"), UPPER("a")`, true}, + + {`SELECT REPLACE('www.mysql.com', 'w', 'Ww')`, true}, + + {`SELECT LOCATE('bar', 'foobarbar');`, true}, + {`SELECT LOCATE('bar', 'foobarbar', 5);`, true}, + + {"select current_date, current_date(), curdate()", true}, + + // For delete statement + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, + + // For time fsp + {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, + + // For row + {"select row(1)", false}, + {"select row(1, 1,)", false}, + {"select (1, 1,)", false}, + {"select row(1, 1) > row(1, 1), row(1, 1, 1) > row(1, 1, 1)", true}, + {"Select (1, 1) > (1, 1)", true}, + {"create table t (row int)", true}, + + // For cast with charset + {"SELECT *, CAST(data AS CHAR CHARACTER SET utf8) FROM t;", true}, + + // For binary operator + {"SELECT binary 'a';", true}, + + // Select time + {"select current_timestamp", true}, + {"select current_timestamp()", true}, + {"select current_timestamp(6)", true}, + {"select now()", true}, + {"select now(6)", true}, + {"select sysdate(), sysdate(6)", true}, + + // For time extract + {`select extract(microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(week from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(month from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(quarter from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(year from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(second_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_hour from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(year_month from "2011-11-11 10:10:10.123456")`, true}, + + // For issue 224 + {`SELECT CAST('test collated returns' AS CHAR CHARACTER SET utf8) COLLATE utf8_bin;`, true}, + + // For trim + {`SELECT TRIM(' bar ');`, true}, + {`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true}, + {`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true}, + {`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true}, + + // For date_add + {`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + + // For date_sub + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestIdentifier(c *C) { + table := []testCase{ + // For quote identifier + {"select `a`, `a.b`, `a b` from t", true}, + // For unquoted identifier + {"create table MergeContextTest$Simple (value integer not null, primary key (value))", true}, + // For as + {"select 1 as a, 1 as `a`, 1 as \"a\", 1 as 'a'", true}, + {`select 1 as a, 1 as "a", 1 as 'a'`, true}, + {`select 1 a, 1 "a", 1 'a'`, true}, + {`select * from t as "a"`, false}, + {`select * from t a`, true}, + {`select * from t as a`, true}, + {"select 1 full, 1 row, 1 abs", true}, + {"select * from t full, t1 row, t2 abs", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestDDL(c *C) { + table := []testCase{ + {"CREATE", false}, + {"CREATE TABLE", false}, + {"CREATE TABLE foo (", false}, + {"CREATE TABLE foo ()", false}, + {"CREATE TABLE foo ();", false}, + {"CREATE TABLE foo (a TINYINT UNSIGNED);", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, + {"CREATE TABLE foo (a bigint unsigned, b bool);", true}, + {"CREATE TABLE foo (a TINYINT, b SMALLINT) CREATE TABLE bar (x INT, y int64)", false}, + {"CREATE TABLE foo (a int, b float); CREATE TABLE bar (x double, y float)", true}, + {"CREATE TABLE foo (a bytes)", false}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) -- foo", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) // foo", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, + {"CREATE TABLE foo /* foo */ (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, + {"CREATE TABLE foo (name CHAR(50) BINARY)", true}, + {"CREATE TABLE foo (name CHAR(50) COLLATE utf8_bin)", true}, + {"CREATE TABLE foo (name CHAR(50) CHARACTER SET utf8)", true}, + {"CREATE TABLE foo (name CHAR(50) BINARY CHARACTER SET utf8 COLLATE utf8_bin)", true}, + + {"CREATE TABLE foo (a.b, b);", false}, + {"CREATE TABLE foo (a, b.c);", false}, + // For table option + {"create table t (c int) avg_row_length = 3", true}, + {"create table t (c int) avg_row_length 3", true}, + {"create table t (c int) checksum = 0", true}, + {"create table t (c int) checksum 1", true}, + {"create table t (c int) compression = none", true}, + {"create table t (c int) compression lz4", true}, + {"create table t (c int) connection = 'abc'", true}, + {"create table t (c int) connection 'abc'", true}, + {"create table t (c int) key_block_size = 1024", true}, + {"create table t (c int) key_block_size 1024", true}, + {"create table t (c int) max_rows = 1000", true}, + {"create table t (c int) max_rows 1000", true}, + {"create table t (c int) min_rows = 1000", true}, + {"create table t (c int) min_rows 1000", true}, + {"create table t (c int) password = 'abc'", true}, + {"create table t (c int) password 'abc'", true}, + // For check clause + {"create table t (c1 bool, c2 bool, check (c1 in (0, 1)), check (c2 in (0, 1)))", true}, + {"CREATE TABLE Customer (SD integer CHECK (SD > 0), First_Name varchar(30));", true}, + + {"create database xxx", true}, + {"create database if exists xxx", false}, + {"create database if not exists xxx", true}, + {"create schema xxx", true}, + {"create schema if exists xxx", false}, + {"create schema if not exists xxx", true}, + // For drop datbase/schema/table + {"drop database xxx", true}, + {"drop database if exists xxx", true}, + {"drop database if not exists xxx", false}, + {"drop schema xxx", true}, + {"drop schema if exists xxx", true}, + {"drop schema if not exists xxx", false}, + {"drop table xxx", true}, + {"drop table xxx, yyy", true}, + {"drop tables xxx", true}, + {"drop tables xxx, yyy", true}, + {"drop table if exists xxx", true}, + {"drop table if not exists xxx", false}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestType(c *C) { + table := []testCase{ + // For time fsp + {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, + + // For hexadecimal + {"SELECT x'0a', X'11', 0x11", true}, + {"select x'0xaa'", false}, + {"select 0X11", false}, + + // For bit + {"select 0b01, 0b0, b'11', B'11'", true}, + {"select 0B01", false}, + {"select 0b21", false}, + + // For enum and set type + {"create table t (c1 enum('a', 'b'), c2 set('a', 'b'))", true}, + {"create table t (c1 enum)", false}, + {"create table t (c1 set)", false}, + + // For blob and text field length + {"create table t (c1 blob(1024), c2 text(1024))", true}, + + // For year + {"create table t (y year(4), y1 year)", true}, + + // For national + {"create table t (c1 national char(2), c2 national varchar(2))", true}, + + // For https://github.com/pingcap/tidb/issues/312 + {`create table t (c float(53));`, true}, + {`create table t (c float(54));`, false}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestPrivilege(c *C) { + table := []testCase{ + // For create user + {`CREATE USER IF NOT EXISTS 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password', 'root'@'127.0.0.1' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + + // For grant statement + {"GRANT ALL ON db1.* TO 'jeffrey'@'localhost';", true}, + {"GRANT SELECT ON db2.invoice TO 'jeffrey'@'localhost';", true}, + {"GRANT ALL ON *.* TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON *.* TO 'someuser'@'somehost';", true}, + {"GRANT ALL ON mydb.* TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON mydb.* TO 'someuser'@'somehost';", true}, + {"GRANT ALL ON mydb.mytbl TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON mydb.mytbl TO 'someuser'@'somehost';", true}, + {"GRANT SELECT (col1), INSERT (col1,col2) ON mydb.mytbl TO 'someuser'@'somehost';", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestComment(c *C) { + table := []testCase{ + {"create table t (c int comment 'comment')", true}, + {"create table t (c int) comment = 'comment'", true}, + {"create table t (c int) comment 'comment'", true}, + {"create table t (c int) comment comment", false}, + {"create table t (comment text)", true}, + // For comment in query + {"/*comment*/ /*comment*/ select c /* this is a comment */ from t;", true}, + } + s.RunTest(c, table) +} +func (s *testParserSuite) TestSubquery(c *C) { + table := []testCase{ + // For compare subquery + {"SELECT 1 > (select 1)", true}, + {"SELECT 1 > ANY (select 1)", true}, + {"SELECT 1 > ALL (select 1)", true}, + {"SELECT 1 > SOME (select 1)", true}, + + // For exists subquery + {"SELECT EXISTS select 1", false}, + {"SELECT EXISTS (select 1)", true}, + {"SELECT + EXISTS (select 1)", true}, + {"SELECT - EXISTS (select 1)", true}, + {"SELECT NOT EXISTS (select 1)", true}, + {"SELECT + NOT EXISTS (select 1)", false}, + {"SELECT - NOT EXISTS (select 1)", false}, + } + s.RunTest(c, table) +} +func (s *testParserSuite) TestUnion(c *C) { + table := []testCase{ + {"select c1 from t1 union select c2 from t2", true}, + {"select c1 from t1 union (select c2 from t2)", true}, + {"select c1 from t1 union (select c2 from t2) order by c1", true}, + {"select c1 from t1 union select c2 from t2 order by c2", true}, + {"select c1 from t1 union (select c2 from t2) limit 1", true}, + {"select c1 from t1 union (select c2 from t2) limit 1, 1", true}, + {"select c1 from t1 union (select c2 from t2) order by c1 limit 1", true}, + {"(select c1 from t1) union distinct select c2 from t2", true}, + {"(select c1 from t1) union all select c2 from t2", true}, + {"(select c1 from t1) union (select c2 from t2) order by c1 union select c3 from t3", false}, + {"(select c1 from t1) union (select c2 from t2) limit 1 union select c3 from t3", false}, + {"(select c1 from t1) union select c2 from t2 union (select c3 from t3) order by c1 limit 1", true}, + {"select (select 1 union select 1) as a", true}, + {"select * from (select 1 union select 2) as a", true}, + {"insert into t select c1 from t1 union select c2 from t2", true}, + {"insert into t (c) select c1 from t1 union select c2 from t2", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestLikeEscape(c *C) { + table := []testCase{ + // For like escape + {`select "abc_" like "abc\\_" escape ''`, true}, + {`select "abc_" like "abc\\_" escape '\\'`, true}, + {`select "abc_" like "abc\\_" escape '||'`, false}, + {`select "abc" like "escape" escape '+'`, true}, + } + + s.RunTest(c, table) +} diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index b58a8f6f3e..9ee7b1b8a9 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -307,6 +307,8 @@ current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e} current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r} database {d}{a}{t}{a}{b}{a}{s}{e} databases {d}{a}{t}{a}{b}{a}{s}{e}{s} +date_add {d}{a}{t}{e}_{a}{d}{d} +date_sub {d}{a}{t}{e}_{s}{u}{b} day {d}{a}{y} dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k} dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h} @@ -355,6 +357,7 @@ in {i}{n} index {i}{n}{d}{e}{x} inner {i}{n}{n}{e}{r} insert {i}{n}{s}{e}{r}{t} +interval {i}{n}{t}{e}{r}{v}{a}{l} into {i}{n}{t}{o} is {i}{s} join {j}{o}{i}{n} @@ -655,6 +658,10 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {database} lval.item = string(l.val) return database {databases} return databases +{date_add} lval.item = string(l.val) + return dateAdd +{date_sub} lval.item = string(l.val) + return dateSub {day} lval.item = string(l.val) return day {dayofweek} lval.item = string(l.val) @@ -738,6 +745,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {index} return index {inner} return inner {insert} return insert +{interval} return interval {into} return into {in} return in {is} return is diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index d43f93f016..a136819ae1 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -112,6 +112,8 @@ func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) { c.funcLocate(v) case *ast.FuncTrimExpr: c.funcTrim(v) + case *ast.FuncDateArithExpr: + c.funcDateArith(v) case *ast.AggregateFuncExpr: c.aggregateFunc(v) } @@ -396,6 +398,21 @@ func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) { c.exprMap[v] = oldTrim } +func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) { + oldDateArith := &expression.DateArith{ + Unit: v.Unit, + Date: c.exprMap[v.Date], + Interval: c.exprMap[v.Interval], + } + switch v.Op { + case ast.DateAdd: + oldDateArith.Op = expression.DateAdd + case ast.DateSub: + oldDateArith.Op = expression.DateSub + } + c.exprMap[v] = oldDateArith +} + func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) { oldAggregate := &expression.Call{ F: v.F, diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 98459ff694..df769237f1 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -123,16 +123,11 @@ func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.De } } if v.Order != nil { - orderRset := &rsets.OrderByRset{} - for _, val := range v.Order { - oldExpr, err := convertExpr(converter, val.Expr) - if err != nil { - return nil, errors.Trace(err) - } - orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} - orderRset.By = append(orderRset.By, orderItem) + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) } - oldDelete.Order = orderRset + oldDelete.Order = oldOrderBy } if v.Limit != nil { oldDelete.Limit = &rsets.LimitRset{Count: v.Limit.Count} @@ -167,16 +162,11 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up oldUpdate.List = append(oldUpdate.List, oldAssign) } if v.Order != nil { - orderRset := &rsets.OrderByRset{} - for _, val := range v.Order { - oldExpr, err := convertExpr(converter, val.Expr) - if err != nil { - return nil, errors.Trace(err) - } - orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} - orderRset.By = append(orderRset.By, orderItem) + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) } - oldUpdate.Order = orderRset + oldUpdate.Order = oldOrderBy } if v.Limit != nil { oldUpdate.Limit = &rsets.LimitRset{Count: v.Limit.Count} From fce519486f0d2ba45a8554de630faf2f6ab74d62 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 4 Nov 2015 17:00:38 +0800 Subject: [PATCH 20/27] ast: visit Select element first in InsertStmt. Address comment. --- ast/dml.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index f8cbd01c11..0b35a70d0e 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -586,6 +586,13 @@ func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { } nod = newNod.(*InsertStmt) + if nod.Select != nil { + node, ok := nod.Select.Accept(v) + if !ok { + return nod, false + } + nod.Select = node.(ResultSetNode) + } node, ok := nod.Table.Accept(v) if !ok { return nod, false @@ -622,13 +629,6 @@ func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { } nod.OnDuplicate[i] = node.(*Assignment) } - if nod.Select != nil { - node, ok := nod.Select.Accept(v) - if !ok { - return nod, false - } - nod.Select = node.(ResultSetNode) - } return v.Leave(nod) } From b8a7e686317b98a650a4e66fb5c8d0c3fdec0c55 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 4 Nov 2015 17:59:12 +0800 Subject: [PATCH 21/27] ast: add comments about implementation rule on Accept method. --- ast/ast.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ast/ast.go b/ast/ast.go index 6b699d796e..a10d99dbd7 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -26,6 +26,11 @@ type Node interface { // Accept accepts Visitor to visit itself. // The returned node should replace original node. // ok returns false to stop visiting. + // + // Implementation of this method should first call visitor.Enter, + // assign the returned node to its method receiver, if skipChildren returns true, + // children should be skipped. Otherwise, call its children in particular order that + // later elements depends on on former elements. Finally, return visitor.Leave. Accept(v Visitor) (node Node, ok bool) // Text returns the original text of the element. Text() string From 524e18dff1adb8a0cc4af1479f9bd94bfa6376d5 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Wed, 4 Nov 2015 19:51:33 +0800 Subject: [PATCH 22/27] ast: address comment. --- ast/ast.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index a10d99dbd7..a1021420f1 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -102,12 +102,12 @@ type ResultSetNode interface { // Visitor visits a Node. type Visitor interface { // VisitEnter is called before children nodes is visited. - // The returned node should be the same type as the input node n. + // The returned node must be the same type as the input node n. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. - // The returned node should be the same type as the input node n. + // The returned node must be the same type as the input node n. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) } From c623fafc727627eac306bc34f8dc44ed3323b441 Mon Sep 17 00:00:00 2001 From: ngaut Date: Thu, 5 Nov 2015 13:53:32 +0800 Subject: [PATCH 23/27] parser: clean up --- ast/parser/parser.y | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ast/parser/parser.y b/ast/parser/parser.y index f61843b28d..145d5ae0e7 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -2696,7 +2696,7 @@ SelectStmt: lastEnd-- } } - lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + lastField.SetText(src[lastField.Offset:lastEnd]) } if $4 != nil { st.Limit = $4.(*ast.Limit) From 9be82dbe27e10f528aa913c1c34dc0df3d55ebda Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 5 Nov 2015 12:44:18 +0800 Subject: [PATCH 24/27] ast: address comment. --- ast/ast.go | 2 +- ast/cloner.go | 8 ++++---- ast/parser/parser.y | 2 +- optimizer/aggregator.go | 4 ++-- optimizer/typecomputer.go | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ast/ast.go b/ast/ast.go index a1021420f1..c38c8bc32d 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -30,7 +30,7 @@ type Node interface { // Implementation of this method should first call visitor.Enter, // assign the returned node to its method receiver, if skipChildren returns true, // children should be skipped. Otherwise, call its children in particular order that - // later elements depends on on former elements. Finally, return visitor.Leave. + // later elements depends on former elements. Finally, return visitor.Leave. Accept(v Visitor) (node Node, ok bool) // Text returns the original text of the element. Text() string diff --git a/ast/cloner.go b/ast/cloner.go index be5d7e441e..d3d8d282e5 100644 --- a/ast/cloner.go +++ b/ast/cloner.go @@ -21,7 +21,7 @@ type Cloner struct { // Enter implements Visitor Enter interface. func (c *Cloner) Enter(node Node) (Node, bool) { - return cloneStruct(node), false + return copyStruct(node), false } // Leave implements Visitor Leave interface. @@ -29,9 +29,9 @@ func (c *Cloner) Leave(in Node) (out Node, ok bool) { return in, true } -// cloneStruct clone a node's struct value, if the struct has slice -// the cloned value should make a new slice and copy old slice to new slice. -func cloneStruct(in Node) (out Node) { +// copyStruct copies a node's struct value, if the struct has slice member, +// make a new slice and copy old slice value to new slice. +func copyStruct(in Node) (out Node) { switch v := in.(type) { case *ValueExpr: nv := *v diff --git a/ast/parser/parser.y b/ast/parser/parser.y index f61843b28d..6997d1cfbf 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -449,7 +449,7 @@ import ( OrderBy "ORDER BY clause" ByItem "BY item" OrderByOptional "Optional ORDER BY clause optional" - ByList "BY list" + ByList "BY list" OuterOpt "optional OUTER clause" QuickOptional "QUICK or empty" PasswordOpt "Password option" diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go index 29b2fa3ea3..452d047c47 100644 --- a/optimizer/aggregator.go +++ b/optimizer/aggregator.go @@ -16,10 +16,10 @@ package optimizer // Aggregator is the interface to // compute aggregate function result. type Aggregator interface { - // Input add a input value to aggregator. + // Input adds an input value to aggregator. // The input values are accumulated in the aggregator. Input(in ...interface{}) error - // Output use input values to compute the aggregated result. + // Output uses input values to compute the aggregated result. Output() interface{} // Clear clears the input values. Clear() diff --git a/optimizer/typecomputer.go b/optimizer/typecomputer.go index d1862d1380..f2cd35aff0 100644 --- a/optimizer/typecomputer.go +++ b/optimizer/typecomputer.go @@ -16,7 +16,7 @@ package optimizer import "github.com/pingcap/tidb/ast" // typeComputer is an ast Visitor that -// Compute types for ast.ExprNode. +// computes result type for ast.ExprNode. type typeComputer struct { err error } From 2f2e31c8f47c4b92d7ddd546a6534b74c6118e6d Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 5 Nov 2015 14:42:15 +0800 Subject: [PATCH 25/27] ast: rename ast node receiver 'nod' to 'n'. --- ast/ddl.go | 272 +++++++++++++-------------- ast/dml.go | 452 ++++++++++++++++++++++----------------------- ast/expressions.go | 440 +++++++++++++++++++++---------------------- ast/functions.go | 220 +++++++++++----------- ast/misc.go | 216 +++++++++++----------- 5 files changed, 800 insertions(+), 800 deletions(-) diff --git a/ast/ddl.go b/ast/ddl.go index 87a5ffba70..4f655d6f01 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -70,13 +70,13 @@ type CreateDatabaseStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateDatabaseStmt) - return v.Leave(nod) + n = newNod.(*CreateDatabaseStmt) + return v.Leave(n) } // DropDatabaseStmt is a statement to drop a database and all tables in the database. @@ -89,13 +89,13 @@ type DropDatabaseStmt struct { } // Accept implements Node Accept interface. -func (nod *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DropDatabaseStmt) - return v.Leave(nod) + n = newNod.(*DropDatabaseStmt) + return v.Leave(n) } // IndexColName is used for parsing index column name from SQL. @@ -107,18 +107,18 @@ type IndexColName struct { } // Accept implements Node Accept interface. -func (nod *IndexColName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IndexColName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IndexColName) - node, ok := nod.Column.Accept(v) + n = newNod.(*IndexColName) + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) - return v.Leave(nod) + n.Column = node.(*ColumnName) + return v.Leave(n) } // ReferenceDef is used for parsing foreign key reference option from SQL. @@ -131,25 +131,25 @@ type ReferenceDef struct { } // Accept implements Node Accept interface. -func (nod *ReferenceDef) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ReferenceDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ReferenceDef) - node, ok := nod.Table.Accept(v) + n = newNod.(*ReferenceDef) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(nod) + return v.Leave(n) } // ColumnOptionType is the type for ColumnOption. @@ -183,20 +183,20 @@ type ColumnOption struct { } // Accept implements Node Accept interface. -func (nod *ColumnOption) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnOption) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnOption) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*ColumnOption) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // ConstraintType is the type for Constraint. @@ -230,27 +230,27 @@ type Constraint struct { } // Accept implements Node Accept interface. -func (nod *Constraint) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Constraint) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Constraint) - for i, val := range nod.Keys { + n = newNod.(*Constraint) + for i, val := range n.Keys { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Keys[i] = node.(*IndexColName) + n.Keys[i] = node.(*IndexColName) } - if nod.Refer != nil { - node, ok := nod.Refer.Accept(v) + if n.Refer != nil { + node, ok := n.Refer.Accept(v) if !ok { - return nod, false + return n, false } - nod.Refer = node.(*ReferenceDef) + n.Refer = node.(*ReferenceDef) } - return v.Leave(nod) + return v.Leave(n) } // ColumnDef is used for parsing column definition from SQL. @@ -263,25 +263,25 @@ type ColumnDef struct { } // Accept implements Node Accept interface. -func (nod *ColumnDef) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnDef) - node, ok := nod.Name.Accept(v) + n = newNod.(*ColumnDef) + node, ok := n.Name.Accept(v) if !ok { - return nod, false + return n, false } - nod.Name = node.(*ColumnName) - for i, val := range nod.Options { + n.Name = node.(*ColumnName) + for i, val := range n.Options { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Options[i] = node.(*ColumnOption) + n.Options[i] = node.(*ColumnOption) } - return v.Leave(nod) + return v.Leave(n) } // CreateTableStmt is a statement to create a table. @@ -297,32 +297,32 @@ type CreateTableStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateTableStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*CreateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.Cols { + n.Table = node.(*TableName) + for i, val := range n.Cols { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Cols[i] = node.(*ColumnDef) + n.Cols[i] = node.(*ColumnDef) } - for i, val := range nod.Constraints { + for i, val := range n.Constraints { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Constraints[i] = node.(*Constraint) + n.Constraints[i] = node.(*Constraint) } - return v.Leave(nod) + return v.Leave(n) } // DropTableStmt is a statement to drop one or more tables. @@ -335,20 +335,20 @@ type DropTableStmt struct { } // Accept implements Node Accept interface. -func (nod *DropTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DropTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DropTableStmt) - for i, val := range nod.Tables { + n = newNod.(*DropTableStmt) + for i, val := range n.Tables { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - return v.Leave(nod) + return v.Leave(n) } // CreateIndexStmt is a statement to create an index. @@ -363,25 +363,25 @@ type CreateIndexStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateIndexStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateIndexStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*CreateIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(nod) + return v.Leave(n) } // DropIndexStmt is a statement to drop the index. @@ -395,18 +395,18 @@ type DropIndexStmt struct { } // Accept implements Node Accept interface. -func (nod *DropIndexStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DropIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DropIndexStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*DropIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - return v.Leave(nod) + n.Table = node.(*TableName) + return v.Leave(n) } // TableOptionType is the type for TableOption @@ -457,20 +457,20 @@ type ColumnPosition struct { } // Accept implements Node Accept interface. -func (nod *ColumnPosition) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnPosition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnPosition) - if nod.RelativeColumn != nil { - node, ok := nod.RelativeColumn.Accept(v) + n = newNod.(*ColumnPosition) + if n.RelativeColumn != nil { + node, ok := n.RelativeColumn.Accept(v) if !ok { - return nod, false + return n, false } - nod.RelativeColumn = node.(*ColumnName) + n.RelativeColumn = node.(*ColumnName) } - return v.Leave(nod) + return v.Leave(n) } // AlterTableType is the type for AlterTableSpec. @@ -503,41 +503,41 @@ type AlterTableSpec struct { } // Accept implements Node Accept interface. -func (nod *AlterTableSpec) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *AlterTableSpec) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*AlterTableSpec) - if nod.Constraint != nil { - node, ok := nod.Constraint.Accept(v) + n = newNod.(*AlterTableSpec) + if n.Constraint != nil { + node, ok := n.Constraint.Accept(v) if !ok { - return nod, false + return n, false } - nod.Constraint = node.(*Constraint) + n.Constraint = node.(*Constraint) } - if nod.Column != nil { - node, ok := nod.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnDef) + n.Column = node.(*ColumnDef) } - if nod.DropColumn != nil { - node, ok := nod.DropColumn.Accept(v) + if n.DropColumn != nil { + node, ok := n.DropColumn.Accept(v) if !ok { - return nod, false + return n, false } - nod.DropColumn = node.(*ColumnName) + n.DropColumn = node.(*ColumnName) } - if nod.Position != nil { - node, ok := nod.Position.Accept(v) + if n.Position != nil { + node, ok := n.Position.Accept(v) if !ok { - return nod, false + return n, false } - nod.Position = node.(*ColumnPosition) + n.Position = node.(*ColumnPosition) } - return v.Leave(nod) + return v.Leave(n) } // AlterTableStmt is a statement to change the structure of a table. @@ -550,25 +550,25 @@ type AlterTableStmt struct { } // Accept implements Node Accept interface. -func (nod *AlterTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *AlterTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*AlterTableStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*AlterTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - for i, val := range nod.Specs { + n.Table = node.(*TableName) + for i, val := range n.Specs { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Specs[i] = node.(*AlterTableSpec) + n.Specs[i] = node.(*AlterTableSpec) } - return v.Leave(nod) + return v.Leave(n) } // TruncateTableStmt is a statement to empty a table completely. @@ -580,16 +580,16 @@ type TruncateTableStmt struct { } // Accept implements Node Accept interface. -func (nod *TruncateTableStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TruncateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TruncateTableStmt) - node, ok := nod.Table.Accept(v) + n = newNod.(*TruncateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) - return v.Leave(nod) + n.Table = node.(*TableName) + return v.Leave(n) } diff --git a/ast/dml.go b/ast/dml.go index 0b35a70d0e..e15b5f83bb 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -60,32 +60,32 @@ type Join struct { } // Accept implements Node Accept interface. -func (nod *Join) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Join) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Join) - node, ok := nod.Left.Accept(v) + n = newNod.(*Join) + node, ok := n.Left.Accept(v) if !ok { - return nod, false + return n, false } - nod.Left = node.(ResultSetNode) - if nod.Right != nil { - node, ok = nod.Right.Accept(v) + n.Left = node.(ResultSetNode) + if n.Right != nil { + node, ok = n.Right.Accept(v) if !ok { - return nod, false + return n, false } - nod.Right = node.(ResultSetNode) + n.Right = node.(ResultSetNode) } - if nod.On != nil { - node, ok = nod.On.Accept(v) + if n.On != nil { + node, ok = n.On.Accept(v) if !ok { - return nod, false + return n, false } - nod.On = node.(*OnCondition) + n.On = node.(*OnCondition) } - return v.Leave(nod) + return v.Leave(n) } // TableName represents a table name. @@ -101,13 +101,13 @@ type TableName struct { } // Accept implements Node Accept interface. -func (nod *TableName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TableName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TableName) - return v.Leave(nod) + n = newNod.(*TableName) + return v.Leave(n) } // TableSource represents table source with a name. @@ -123,18 +123,18 @@ type TableSource struct { } // Accept implements Node Accept interface. -func (nod *TableSource) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TableSource) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TableSource) - node, ok := nod.Source.Accept(v) + n = newNod.(*TableSource) + node, ok := n.Source.Accept(v) if !ok { - return nod, false + return n, false } - nod.Source = node.(ResultSetNode) - return v.Leave(nod) + n.Source = node.(ResultSetNode) + return v.Leave(n) } // OnCondition represetns JOIN on condition. @@ -145,28 +145,28 @@ type OnCondition struct { } // Accept implements Node Accept interface. -func (nod *OnCondition) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *OnCondition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*OnCondition) - node, ok := nod.Expr.Accept(v) + n = newNod.(*OnCondition) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // SetResultFields implements ResultSet interface. -func (nod *TableSource) SetResultFields(rfs []*ResultField) { - nod.Source.SetResultFields(rfs) +func (n *TableSource) SetResultFields(rfs []*ResultField) { + n.Source.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (nod *TableSource) GetResultFields() []*ResultField { - return nod.Source.GetResultFields() +func (n *TableSource) GetResultFields() []*ResultField { + return n.Source.GetResultFields() } // SelectLockType is the lock type for SelectStmt. @@ -188,13 +188,13 @@ type WildCardField struct { } // Accept implements Node Accept interface. -func (nod *WildCardField) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *WildCardField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*WildCardField) - return v.Leave(nod) + n = newNod.(*WildCardField) + return v.Leave(n) } // SelectField represents fields in select statement. @@ -214,20 +214,20 @@ type SelectField struct { } // Accept implements Node Accept interface. -func (nod *SelectField) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SelectField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SelectField) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*SelectField) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // FieldList represents field list in select statement. @@ -238,20 +238,20 @@ type FieldList struct { } // Accept implements Node Accept interface. -func (nod *FieldList) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FieldList) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FieldList) - for i, val := range nod.Fields { + n = newNod.(*FieldList) + for i, val := range n.Fields { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Fields[i] = node.(*SelectField) + n.Fields[i] = node.(*SelectField) } - return v.Leave(nod) + return v.Leave(n) } // TableRefsClause represents table references clause in dml statement. @@ -262,18 +262,18 @@ type TableRefsClause struct { } // Accept implements Node Accept interface. -func (nod *TableRefsClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *TableRefsClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*TableRefsClause) - node, ok := nod.TableRefs.Accept(v) + n = newNod.(*TableRefsClause) + node, ok := n.TableRefs.Accept(v) if !ok { - return nod, false + return n, false } - nod.TableRefs = node.(*Join) - return v.Leave(nod) + n.TableRefs = node.(*Join) + return v.Leave(n) } // ByItem represents an item in order by or group by. @@ -285,18 +285,18 @@ type ByItem struct { } // Accept implements Node Accept interface. -func (nod *ByItem) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ByItem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ByItem) - node, ok := nod.Expr.Accept(v) + n = newNod.(*ByItem) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // GroupByClause represents group by clause. @@ -306,20 +306,20 @@ type GroupByClause struct { } // Accept implements Node Accept interface. -func (nod *GroupByClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *GroupByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*GroupByClause) - for i, val := range nod.Items { + n = newNod.(*GroupByClause) + for i, val := range n.Items { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Items[i] = node.(*ByItem) + n.Items[i] = node.(*ByItem) } - return v.Leave(nod) + return v.Leave(n) } // HavingClause represents having clause. @@ -329,18 +329,18 @@ type HavingClause struct { } // Accept implements Node Accept interface. -func (nod *HavingClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *HavingClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*HavingClause) - node, ok := nod.Expr.Accept(v) + n = newNod.(*HavingClause) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // OrderByClause represents order by clause. @@ -351,20 +351,20 @@ type OrderByClause struct { } // Accept implements Node Accept interface. -func (nod *OrderByClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *OrderByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*OrderByClause) - for i, val := range nod.Items { + n = newNod.(*OrderByClause) + for i, val := range n.Items { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Items[i] = node.(*ByItem) + n.Items[i] = node.(*ByItem) } - return v.Leave(nod) + return v.Leave(n) } // SelectStmt represents the select query node. @@ -394,69 +394,69 @@ type SelectStmt struct { } // Accept implements Node Accept interface. -func (nod *SelectStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SelectStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SelectStmt) + n = newNod.(*SelectStmt) - if nod.From != nil { - node, ok := nod.From.Accept(v) + if n.From != nil { + node, ok := n.From.Accept(v) if !ok { - return nod, false + return n, false } - nod.From = node.(*TableRefsClause) + n.From = node.(*TableRefsClause) } - if nod.Where != nil { - node, ok := nod.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - if nod.Fields != nil { - node, ok := nod.Fields.Accept(v) + if n.Fields != nil { + node, ok := n.Fields.Accept(v) if !ok { - return nod, false + return n, false } - nod.Fields = node.(*FieldList) + n.Fields = node.(*FieldList) } - if nod.GroupBy != nil { - node, ok := nod.GroupBy.Accept(v) + if n.GroupBy != nil { + node, ok := n.GroupBy.Accept(v) if !ok { - return nod, false + return n, false } - nod.GroupBy = node.(*GroupByClause) + n.GroupBy = node.(*GroupByClause) } - if nod.Having != nil { - node, ok := nod.Having.Accept(v) + if n.Having != nil { + node, ok := n.Having.Accept(v) if !ok { - return nod, false + return n, false } - nod.Having = node.(*HavingClause) + n.Having = node.(*HavingClause) } - if nod.OrderBy != nil { - node, ok := nod.OrderBy.Accept(v) + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) if !ok { - return nod, false + return n, false } - nod.OrderBy = node.(*OrderByClause) + n.OrderBy = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok := nod.Limit.Accept(v) + if n.Limit != nil { + node, ok := n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. @@ -468,18 +468,18 @@ type UnionClause struct { } // Accept implements Node Accept interface. -func (nod *UnionClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UnionClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UnionClause) - node, ok := nod.Select.Accept(v) + n = newNod.(*UnionClause) + node, ok := n.Select.Accept(v) if !ok { - return nod, false + return n, false } - nod.Select = node.(*SelectStmt) - return v.Leave(nod) + n.Select = node.(*SelectStmt) + return v.Leave(n) } // UnionStmt represents "union statement" @@ -495,34 +495,34 @@ type UnionStmt struct { } // Accept implements Node Accept interface. -func (nod *UnionStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UnionStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UnionStmt) - for i, val := range nod.Selects { + n = newNod.(*UnionStmt) + for i, val := range n.Selects { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Selects[i] = node.(*SelectStmt) + n.Selects[i] = node.(*SelectStmt) } - if nod.OrderBy != nil { - node, ok := nod.OrderBy.Accept(v) + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) if !ok { - return nod, false + return n, false } - nod.OrderBy = node.(*OrderByClause) + n.OrderBy = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok := nod.Limit.Accept(v) + if n.Limit != nil { + node, ok := n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // Assignment is the expression for assignment, like a = 1. @@ -535,23 +535,23 @@ type Assignment struct { } // Accept implements Node Accept interface. -func (nod *Assignment) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Assignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Assignment) - node, ok := nod.Column.Accept(v) + n = newNod.(*Assignment) + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) - node, ok = nod.Expr.Accept(v) + n.Column = node.(*ColumnName) + node, ok = n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // Priority const values. @@ -579,57 +579,57 @@ type InsertStmt struct { } // Accept implements Node Accept interface. -func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *InsertStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*InsertStmt) + n = newNod.(*InsertStmt) - if nod.Select != nil { - node, ok := nod.Select.Accept(v) + if n.Select != nil { + node, ok := n.Select.Accept(v) if !ok { - return nod, false + return n, false } - nod.Select = node.(ResultSetNode) + n.Select = node.(ResultSetNode) } - node, ok := nod.Table.Accept(v) + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableRefsClause) + n.Table = node.(*TableRefsClause) - for i, val := range nod.Columns { + for i, val := range n.Columns { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Columns[i] = node.(*ColumnName) + n.Columns[i] = node.(*ColumnName) } - for i, list := range nod.Lists { + for i, list := range n.Lists { for j, val := range list { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Lists[i][j] = node.(ExprNode) + n.Lists[i][j] = node.(ExprNode) } } - for i, val := range nod.Setlist { + for i, val := range n.Setlist { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Setlist[i] = node.(*Assignment) + n.Setlist[i] = node.(*Assignment) } - for i, val := range nod.OnDuplicate { + for i, val := range n.OnDuplicate { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.OnDuplicate[i] = node.(*Assignment) + n.OnDuplicate[i] = node.(*Assignment) } - return v.Leave(nod) + return v.Leave(n) } // DeleteStmt is a statement to delete rows from table. @@ -652,49 +652,49 @@ type DeleteStmt struct { } // Accept implements Node Accept interface. -func (nod *DeleteStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DeleteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DeleteStmt) + n = newNod.(*DeleteStmt) - node, ok := nod.TableRefs.Accept(v) + node, ok := n.TableRefs.Accept(v) if !ok { - return nod, false + return n, false } - nod.TableRefs = node.(*TableRefsClause) + n.TableRefs = node.(*TableRefsClause) - for i, val := range nod.Tables { + for i, val := range n.Tables { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - if nod.Where != nil { - node, ok = nod.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - if nod.Order != nil { - node, ok = nod.Order.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return nod, false + return n, false } - nod.Order = node.(*OrderByClause) + n.Order = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok = nod.Limit.Accept(v) + if n.Limit != nil { + node, ok = n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // UpdateStmt is a statement to update columns of existing rows in tables with new values. @@ -713,46 +713,46 @@ type UpdateStmt struct { } // Accept implements Node Accept interface. -func (nod *UpdateStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UpdateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UpdateStmt) - node, ok := nod.TableRefs.Accept(v) + n = newNod.(*UpdateStmt) + node, ok := n.TableRefs.Accept(v) if !ok { - return nod, false + return n, false } - nod.TableRefs = node.(*TableRefsClause) - for i, val := range nod.List { + n.TableRefs = node.(*TableRefsClause) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.List[i] = node.(*Assignment) + n.List[i] = node.(*Assignment) } - if nod.Where != nil { - node, ok = nod.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - if nod.Order != nil { - node, ok = nod.Order.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return nod, false + return n, false } - nod.Order = node.(*OrderByClause) + n.Order = node.(*OrderByClause) } - if nod.Limit != nil { - node, ok = nod.Limit.Accept(v) + if n.Limit != nil { + node, ok = n.Limit.Accept(v) if !ok { - return nod, false + return n, false } - nod.Limit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(nod) + return v.Leave(n) } // Limit is the limit clause. @@ -764,11 +764,11 @@ type Limit struct { } // Accept implements Node Accept interface. -func (nod *Limit) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *Limit) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*Limit) - return v.Leave(nod) + n = newNod.(*Limit) + return v.Leave(n) } diff --git a/ast/expressions.go b/ast/expressions.go index cddb7f45d5..c3278bd331 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -91,18 +91,18 @@ func NewValueExpr(value interface{}) *ValueExpr { } // IsStatic implements ExprNode interface. -func (nod *ValueExpr) IsStatic() bool { +func (n *ValueExpr) IsStatic() bool { return true } // Accept implements Node interface. -func (nod *ValueExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ValueExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ValueExpr) - return v.Leave(nod) + n = newNod.(*ValueExpr) + return v.Leave(n) } // BetweenExpr is for "between and" or "not between and" expression. @@ -119,37 +119,37 @@ type BetweenExpr struct { } // Accept implements Node interface. -func (nod *BetweenExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *BetweenExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*BetweenExpr) + n = newNod.(*BetweenExpr) - node, ok := nod.Expr.Accept(v) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = nod.Left.Accept(v) + node, ok = n.Left.Accept(v) if !ok { - return nod, false + return n, false } - nod.Left = node.(ExprNode) + n.Left = node.(ExprNode) - node, ok = nod.Right.Accept(v) + node, ok = n.Right.Accept(v) if !ok { - return nod, false + return n, false } - nod.Right = node.(ExprNode) + n.Right = node.(ExprNode) - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *BetweenExpr) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Left.IsStatic() && nod.Right.IsStatic() +func (n *BetweenExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Left.IsStatic() && n.Right.IsStatic() } // BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc. @@ -164,31 +164,31 @@ type BinaryOperationExpr struct { } // Accept implements Node interface. -func (nod *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*BinaryOperationExpr) + n = newNod.(*BinaryOperationExpr) - node, ok := nod.L.Accept(v) + node, ok := n.L.Accept(v) if !ok { - return nod, false + return n, false } - nod.L = node.(ExprNode) + n.L = node.(ExprNode) - node, ok = nod.R.Accept(v) + node, ok = n.R.Accept(v) if !ok { - return nod, false + return n, false } - nod.R = node.(ExprNode) + n.R = node.(ExprNode) - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *BinaryOperationExpr) IsStatic() bool { - return nod.L.IsStatic() && nod.R.IsStatic() +func (n *BinaryOperationExpr) IsStatic() bool { + return n.L.IsStatic() && n.R.IsStatic() } // WhenClause is the when clause in Case expression for "when condition then result". @@ -201,29 +201,29 @@ type WhenClause struct { } // Accept implements Node Accept interface. -func (nod *WhenClause) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *WhenClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*WhenClause) - node, ok := nod.Expr.Accept(v) + n = newNod.(*WhenClause) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = nod.Result.Accept(v) + node, ok = n.Result.Accept(v) if !ok { - return nod, false + return n, false } - nod.Result = node.(ExprNode) - return v.Leave(nod) + n.Result = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *WhenClause) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Result.IsStatic() +func (n *WhenClause) IsStatic() bool { + return n.Expr.IsStatic() && n.Result.IsStatic() } // CaseExpr is the case expression. @@ -238,47 +238,47 @@ type CaseExpr struct { } // Accept implements Node Accept interface. -func (nod *CaseExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CaseExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CaseExpr) - if nod.Value != nil { - node, ok := nod.Value.Accept(v) + n = newNod.(*CaseExpr) + if n.Value != nil { + node, ok := n.Value.Accept(v) if !ok { - return nod, false + return n, false } - nod.Value = node.(ExprNode) + n.Value = node.(ExprNode) } - for i, val := range nod.WhenClauses { + for i, val := range n.WhenClauses { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.WhenClauses[i] = node.(*WhenClause) + n.WhenClauses[i] = node.(*WhenClause) } - if nod.ElseClause != nil { - node, ok := nod.ElseClause.Accept(v) + if n.ElseClause != nil { + node, ok := n.ElseClause.Accept(v) if !ok { - return nod, false + return n, false } - nod.ElseClause = node.(ExprNode) + n.ElseClause = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *CaseExpr) IsStatic() bool { - if nod.Value != nil && !nod.Value.IsStatic() { +func (n *CaseExpr) IsStatic() bool { + if n.Value != nil && !n.Value.IsStatic() { return false } - for _, w := range nod.WhenClauses { + for _, w := range n.WhenClauses { if !w.IsStatic() { return false } } - if nod.ElseClause != nil && !nod.ElseClause.IsStatic() { + if n.ElseClause != nil && !n.ElseClause.IsStatic() { return false } return true @@ -292,28 +292,28 @@ type SubqueryExpr struct { } // Accept implements Node Accept interface. -func (nod *SubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SubqueryExpr) - node, ok := nod.Query.Accept(v) + n = newNod.(*SubqueryExpr) + node, ok := n.Query.Accept(v) if !ok { - return nod, false + return n, false } - nod.Query = node.(ResultSetNode) - return v.Leave(nod) + n.Query = node.(ResultSetNode) + return v.Leave(n) } // SetResultFields implements ResultSet interface. -func (nod *SubqueryExpr) SetResultFields(rfs []*ResultField) { - nod.Query.SetResultFields(rfs) +func (n *SubqueryExpr) SetResultFields(rfs []*ResultField) { + n.Query.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (nod *SubqueryExpr) GetResultFields() []*ResultField { - return nod.Query.GetResultFields() +func (n *SubqueryExpr) GetResultFields() []*ResultField { + return n.Query.GetResultFields() } // CompareSubqueryExpr is the expression for "expr cmp (select ...)". @@ -333,23 +333,23 @@ type CompareSubqueryExpr struct { } // Accept implements Node Accept interface. -func (nod *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CompareSubqueryExpr) - node, ok := nod.L.Accept(v) + n = newNod.(*CompareSubqueryExpr) + node, ok := n.L.Accept(v) if !ok { - return nod, false + return n, false } - nod.L = node.(ExprNode) - node, ok = nod.R.Accept(v) + n.L = node.(ExprNode) + node, ok = n.R.Accept(v) if !ok { - return nod, false + return n, false } - nod.R = node.(*SubqueryExpr) - return v.Leave(nod) + n.R = node.(*SubqueryExpr) + return v.Leave(n) } // ColumnName represents column name. @@ -365,13 +365,13 @@ type ColumnName struct { } // Accept implements Node Accept interface. -func (nod *ColumnName) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnName) - return v.Leave(nod) + n = newNod.(*ColumnName) + return v.Leave(n) } // ColumnNameExpr represents a column name expression. @@ -383,18 +383,18 @@ type ColumnNameExpr struct { } // Accept implements Node Accept interface. -func (nod *ColumnNameExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ColumnNameExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ColumnNameExpr) - node, ok := nod.Name.Accept(v) + n = newNod.(*ColumnNameExpr) + node, ok := n.Name.Accept(v) if !ok { - return nod, false + return n, false } - nod.Name = node.(*ColumnName) - return v.Leave(nod) + n.Name = node.(*ColumnName) + return v.Leave(n) } // DefaultExpr is the default expression using default value for a column. @@ -405,20 +405,20 @@ type DefaultExpr struct { } // Accept implements Node Accept interface. -func (nod *DefaultExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DefaultExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DefaultExpr) - if nod.Name != nil { - node, ok := nod.Name.Accept(v) + n = newNod.(*DefaultExpr) + if n.Name != nil { + node, ok := n.Name.Accept(v) if !ok { - return nod, false + return n, false } - nod.Name = node.(*ColumnName) + n.Name = node.(*ColumnName) } - return v.Leave(nod) + return v.Leave(n) } // IdentifierExpr represents an identifier expression. @@ -429,13 +429,13 @@ type IdentifierExpr struct { } // Accept implements Node Accept interface. -func (nod *IdentifierExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IdentifierExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IdentifierExpr) - return v.Leave(nod) + n = newNod.(*IdentifierExpr) + return v.Leave(n) } // ExistsSubqueryExpr is the expression for "exists (select ...)". @@ -447,18 +447,18 @@ type ExistsSubqueryExpr struct { } // Accept implements Node Accept interface. -func (nod *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ExistsSubqueryExpr) - node, ok := nod.Sel.Accept(v) + n = newNod.(*ExistsSubqueryExpr) + node, ok := n.Sel.Accept(v) if !ok { - return nod, false + return n, false } - nod.Sel = node.(*SubqueryExpr) - return v.Leave(nod) + n.Sel = node.(*SubqueryExpr) + return v.Leave(n) } // PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)". @@ -475,32 +475,32 @@ type PatternInExpr struct { } // Accept implements Node Accept interface. -func (nod *PatternInExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PatternInExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PatternInExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*PatternInExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - for i, val := range nod.List { + n.Expr = node.(ExprNode) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return nod, false + return n, false } - nod.List[i] = node.(ExprNode) + n.List[i] = node.(ExprNode) } - if nod.Sel != nil { - node, ok = nod.Sel.Accept(v) + if n.Sel != nil { + node, ok = n.Sel.Accept(v) if !ok { - return nod, false + return n, false } - nod.Sel = node.(*SubqueryExpr) + n.Sel = node.(*SubqueryExpr) } - return v.Leave(nod) + return v.Leave(n) } // IsNullExpr is the expression for null check. @@ -513,23 +513,23 @@ type IsNullExpr struct { } // Accept implements Node Accept interface. -func (nod *IsNullExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IsNullExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IsNullExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*IsNullExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *IsNullExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *IsNullExpr) IsStatic() bool { + return n.Expr.IsStatic() } // IsTruthExpr is the expression for true/false check. @@ -544,23 +544,23 @@ type IsTruthExpr struct { } // Accept implements Node Accept interface. -func (nod *IsTruthExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *IsTruthExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*IsTruthExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*IsTruthExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *IsTruthExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *IsTruthExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PatternLikeExpr is the expression for like operator, e.g, expr like "%123%" @@ -577,32 +577,32 @@ type PatternLikeExpr struct { } // Accept implements Node Accept interface. -func (nod *PatternLikeExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PatternLikeExpr) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*PatternLikeExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - if nod.Pattern != nil { - node, ok := nod.Pattern.Accept(v) + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pattern = node.(ExprNode) + n.Pattern = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *PatternLikeExpr) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Pattern.IsStatic() +func (n *PatternLikeExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // ParamMarkerExpr expresion holds a place for another expression. @@ -613,13 +613,13 @@ type ParamMarkerExpr struct { } // Accept implements Node Accept interface. -func (nod *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ParamMarkerExpr) - return v.Leave(nod) + n = newNod.(*ParamMarkerExpr) + return v.Leave(n) } // ParenthesesExpr is the parentheses expression. @@ -630,25 +630,25 @@ type ParenthesesExpr struct { } // Accept implements Node Accept interface. -func (nod *ParenthesesExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ParenthesesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ParenthesesExpr) - if nod.Expr != nil { - node, ok := nod.Expr.Accept(v) + n = newNod.(*ParenthesesExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *ParenthesesExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *ParenthesesExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PositionExpr is the expression for order by and group by position. @@ -663,18 +663,18 @@ type PositionExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (nod *PositionExpr) IsStatic() bool { +func (n *PositionExpr) IsStatic() bool { return true } // Accept implements Node Accept interface. -func (nod *PositionExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PositionExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PositionExpr) - return v.Leave(nod) + n = newNod.(*PositionExpr) + return v.Leave(n) } // PatternRegexpExpr is the pattern expression for pattern match. @@ -689,28 +689,28 @@ type PatternRegexpExpr struct { } // Accept implements Node Accept interface. -func (nod *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PatternRegexpExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*PatternRegexpExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - node, ok = nod.Pattern.Accept(v) + n.Expr = node.(ExprNode) + node, ok = n.Pattern.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pattern = node.(ExprNode) - return v.Leave(nod) + n.Pattern = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *PatternRegexpExpr) IsStatic() bool { - return nod.Expr.IsStatic() && nod.Pattern.IsStatic() +func (n *PatternRegexpExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // RowExpr is the expression for row constructor. @@ -722,25 +722,25 @@ type RowExpr struct { } // Accept implements Node Accept interface. -func (nod *RowExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *RowExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*RowExpr) - for i, val := range nod.Values { + n = newNod.(*RowExpr) + for i, val := range n.Values { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Values[i] = node.(ExprNode) + n.Values[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *RowExpr) IsStatic() bool { - for _, v := range nod.Values { +func (n *RowExpr) IsStatic() bool { + for _, v := range n.Values { if !v.IsStatic() { return false } @@ -758,23 +758,23 @@ type UnaryOperationExpr struct { } // Accept implements Node Accept interface. -func (nod *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UnaryOperationExpr) - node, ok := nod.V.Accept(v) + n = newNod.(*UnaryOperationExpr) + node, ok := n.V.Accept(v) if !ok { - return nod, false + return n, false } - nod.V = node.(ExprNode) - return v.Leave(nod) + n.V = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *UnaryOperationExpr) IsStatic() bool { - return nod.V.IsStatic() +func (n *UnaryOperationExpr) IsStatic() bool { + return n.V.IsStatic() } // ValuesExpr is the expression used in INSERT VALUES @@ -785,18 +785,18 @@ type ValuesExpr struct { } // Accept implements Node Accept interface. -func (nod *ValuesExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ValuesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ValuesExpr) - node, ok := nod.Column.Accept(v) + n = newNod.(*ValuesExpr) + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) - return v.Leave(nod) + n.Column = node.(*ColumnName) + return v.Leave(n) } // VariableExpr is the expression for variable. @@ -811,11 +811,11 @@ type VariableExpr struct { } // Accept implements Node Accept interface. -func (nod *VariableExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *VariableExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*VariableExpr) - return v.Leave(nod) + n = newNod.(*VariableExpr) + return v.Leave(n) } diff --git a/ast/functions.go b/ast/functions.go index 8d688f9eee..7faab4c6e1 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -45,30 +45,30 @@ type FuncCallExpr struct { } // Accept implements Node interface. -func (nod *FuncCallExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncCallExpr) - for i, val := range nod.Args { + n = newNod.(*FuncCallExpr) + for i, val := range n.Args { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Args[i] = node.(ExprNode) + n.Args[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(nod.FnName)] +func (n *FuncCallExpr) IsStatic() bool { + v := builtin.Funcs[strings.ToLower(n.FnName)] if v.F == nil || !v.IsStatic { return false } - for _, v := range nod.Args { + for _, v := range n.Args { if !v.IsStatic() { return false } @@ -86,23 +86,23 @@ type FuncExtractExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncExtractExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncExtractExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncExtractExpr) - node, ok := nod.Date.Accept(v) + n = newNod.(*FuncExtractExpr) + node, ok := n.Date.Accept(v) if !ok { - return nod, false + return n, false } - nod.Date = node.(ExprNode) - return v.Leave(nod) + n.Date = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncExtractExpr) IsStatic() bool { - return nod.Date.IsStatic() +func (n *FuncExtractExpr) IsStatic() bool { + return n.Date.IsStatic() } // FuncConvertExpr provides a way to convert data between different character sets. @@ -116,23 +116,23 @@ type FuncConvertExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncConvertExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *FuncConvertExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (nod *FuncConvertExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncConvertExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncConvertExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*FuncConvertExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // CastFunctionType is the type for cast function. @@ -158,23 +158,23 @@ type FuncCastExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncCastExpr) IsStatic() bool { - return nod.Expr.IsStatic() +func (n *FuncCastExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (nod *FuncCastExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncCastExpr) - node, ok := nod.Expr.Accept(v) + n = newNod.(*FuncCastExpr) + node, ok := n.Expr.Accept(v) if !ok { - return nod, false + return n, false } - nod.Expr = node.(ExprNode) - return v.Leave(nod) + n.Expr = node.(ExprNode) + return v.Leave(n) } // FuncSubstringExpr returns the substring as specified. @@ -188,35 +188,35 @@ type FuncSubstringExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncSubstringExpr) - node, ok := nod.StrExpr.Accept(v) + n = newNod.(*FuncSubstringExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return nod, false + return n, false } - nod.StrExpr = node.(ExprNode) - node, ok = nod.Pos.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pos = node.(ExprNode) - if nod.Len != nil { - node, ok = nod.Len.Accept(v) + n.Pos = node.(ExprNode) + if n.Len != nil { + node, ok = n.Len.Accept(v) if !ok { - return nod, false + return n, false } - nod.Len = node.(ExprNode) + n.Len = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncSubstringExpr) IsStatic() bool { - return nod.StrExpr.IsStatic() && nod.Pos.IsStatic() && nod.Len.IsStatic() +func (n *FuncSubstringExpr) IsStatic() bool { + return n.StrExpr.IsStatic() && n.Pos.IsStatic() && n.Len.IsStatic() } // FuncSubstringIndexExpr returns the substring as specified. @@ -230,28 +230,28 @@ type FuncSubstringIndexExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncSubstringIndexExpr) - node, ok := nod.StrExpr.Accept(v) + n = newNod.(*FuncSubstringIndexExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return nod, false + return n, false } - nod.StrExpr = node.(ExprNode) - node, ok = nod.Delim.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Delim.Accept(v) if !ok { - return nod, false + return n, false } - nod.Delim = node.(ExprNode) - node, ok = nod.Count.Accept(v) + n.Delim = node.(ExprNode) + node, ok = n.Count.Accept(v) if !ok { - return nod, false + return n, false } - nod.Count = node.(ExprNode) - return v.Leave(nod) + n.Count = node.(ExprNode) + return v.Leave(n) } // FuncLocateExpr returns the position of the first occurrence of substring. @@ -265,28 +265,28 @@ type FuncLocateExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncLocateExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncLocateExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncLocateExpr) - node, ok := nod.Str.Accept(v) + n = newNod.(*FuncLocateExpr) + node, ok := n.Str.Accept(v) if !ok { - return nod, false + return n, false } - nod.Str = node.(ExprNode) - node, ok = nod.SubStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.SubStr.Accept(v) if !ok { - return nod, false + return n, false } - nod.SubStr = node.(ExprNode) - node, ok = nod.Pos.Accept(v) + n.SubStr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pos = node.(ExprNode) - return v.Leave(nod) + n.Pos = node.(ExprNode) + return v.Leave(n) } // TrimDirectionType is the type for trim direction. @@ -314,28 +314,28 @@ type FuncTrimExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncTrimExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncTrimExpr) - node, ok := nod.Str.Accept(v) + n = newNod.(*FuncTrimExpr) + node, ok := n.Str.Accept(v) if !ok { - return nod, false + return n, false } - nod.Str = node.(ExprNode) - node, ok = nod.RemStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.RemStr.Accept(v) if !ok { - return nod, false + return n, false } - nod.RemStr = node.(ExprNode) - return v.Leave(nod) + n.RemStr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (nod *FuncTrimExpr) IsStatic() bool { - return nod.Str.IsStatic() && nod.RemStr.IsStatic() +func (n *FuncTrimExpr) IsStatic() bool { + return n.Str.IsStatic() && n.RemStr.IsStatic() } // DateArithType is type for DateArith option. @@ -361,27 +361,27 @@ type FuncDateArithExpr struct { } // Accept implements Node Accept interface. -func (nod *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*FuncDateArithExpr) - if nod.Date != nil { - node, ok := nod.Date.Accept(v) + n = newNod.(*FuncDateArithExpr) + if n.Date != nil { + node, ok := n.Date.Accept(v) if !ok { - return nod, false + return n, false } - nod.Date = node.(ExprNode) + n.Date = node.(ExprNode) } - if nod.Interval != nil { - node, ok := nod.Date.Accept(v) + if n.Interval != nil { + node, ok := n.Date.Accept(v) if !ok { - return nod, false + return n, false } - nod.Date = node.(ExprNode) + n.Date = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // AggregateFuncExpr represents aggregate function expression. @@ -398,18 +398,18 @@ type AggregateFuncExpr struct { } // Accept implements Node Accept interface. -func (nod *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*AggregateFuncExpr) - for i, val := range nod.Args { + n = newNod.(*AggregateFuncExpr) + for i, val := range n.Args { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Args[i] = node.(ExprNode) + n.Args[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } diff --git a/ast/misc.go b/ast/misc.go index b0f9e845fc..4f287b4680 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -61,18 +61,18 @@ type ExplainStmt struct { } // Accept implements Node Accept interface. -func (nod *ExplainStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ExplainStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ExplainStmt) - node, ok := nod.Stmt.Accept(v) + n = newNod.(*ExplainStmt) + node, ok := n.Stmt.Accept(v) if !ok { - return nod, false + return n, false } - nod.Stmt = node.(DMLNode) - return v.Leave(nod) + n.Stmt = node.(DMLNode) + return v.Leave(n) } // PrepareStmt is a statement to prepares a SQL statement which contains placeholders, @@ -89,18 +89,18 @@ type PrepareStmt struct { } // Accept implements Node Accept interface. -func (nod *PrepareStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PrepareStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PrepareStmt) - node, ok := nod.SQLVar.Accept(v) + n = newNod.(*PrepareStmt) + node, ok := n.SQLVar.Accept(v) if !ok { - return nod, false + return n, false } - nod.SQLVar = node.(*VariableExpr) - return v.Leave(nod) + n.SQLVar = node.(*VariableExpr) + return v.Leave(n) } // DeallocateStmt is a statement to release PreparedStmt. @@ -113,13 +113,13 @@ type DeallocateStmt struct { } // Accept implements Node Accept interface. -func (nod *DeallocateStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DeallocateStmt) - return v.Leave(nod) + n = newNod.(*DeallocateStmt) + return v.Leave(n) } // ExecuteStmt is a statement to execute PreparedStmt. @@ -133,20 +133,20 @@ type ExecuteStmt struct { } // Accept implements Node Accept interface. -func (nod *ExecuteStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ExecuteStmt) - for i, val := range nod.UsingVars { + n = newNod.(*ExecuteStmt) + for i, val := range n.UsingVars { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.UsingVars[i] = node.(ExprNode) + n.UsingVars[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // ShowStmtType is the type for SHOW statement. @@ -187,41 +187,41 @@ type ShowStmt struct { } // Accept implements Node Accept interface. -func (nod *ShowStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *ShowStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*ShowStmt) - if nod.Table != nil { - node, ok := nod.Table.Accept(v) + n = newNod.(*ShowStmt) + if n.Table != nil { + node, ok := n.Table.Accept(v) if !ok { - return nod, false + return n, false } - nod.Table = node.(*TableName) + n.Table = node.(*TableName) } - if nod.Column != nil { - node, ok := nod.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return nod, false + return n, false } - nod.Column = node.(*ColumnName) + n.Column = node.(*ColumnName) } - if nod.Pattern != nil { - node, ok := nod.Pattern.Accept(v) + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) if !ok { - return nod, false + return n, false } - nod.Pattern = node.(*PatternLikeExpr) + n.Pattern = node.(*PatternLikeExpr) } - if nod.Where != nil { - node, ok := nod.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return nod, false + return n, false } - nod.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // BeginStmt is a statement to start a new transaction. @@ -231,13 +231,13 @@ type BeginStmt struct { } // Accept implements Node Accept interface. -func (nod *BeginStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *BeginStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*BeginStmt) - return v.Leave(nod) + n = newNod.(*BeginStmt) + return v.Leave(n) } // CommitStmt is a statement to commit the current transaction. @@ -247,13 +247,13 @@ type CommitStmt struct { } // Accept implements Node Accept interface. -func (nod *CommitStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CommitStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CommitStmt) - return v.Leave(nod) + n = newNod.(*CommitStmt) + return v.Leave(n) } // RollbackStmt is a statement to roll back the current transaction. @@ -263,13 +263,13 @@ type RollbackStmt struct { } // Accept implements Node Accept interface. -func (nod *RollbackStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *RollbackStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*RollbackStmt) - return v.Leave(nod) + n = newNod.(*RollbackStmt) + return v.Leave(n) } // UseStmt is a statement to use the DBName database as the current database. @@ -281,13 +281,13 @@ type UseStmt struct { } // Accept implements Node Accept interface. -func (nod *UseStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *UseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*UseStmt) - return v.Leave(nod) + n = newNod.(*UseStmt) + return v.Leave(n) } // VariableAssignment is a variable assignment struct. @@ -300,18 +300,18 @@ type VariableAssignment struct { } // Accept implements Node interface. -func (nod *VariableAssignment) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *VariableAssignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*VariableAssignment) - node, ok := nod.Value.Accept(v) + n = newNod.(*VariableAssignment) + node, ok := n.Value.Accept(v) if !ok { - return nod, false + return n, false } - nod.Value = node.(ExprNode) - return v.Leave(nod) + n.Value = node.(ExprNode) + return v.Leave(n) } // SetStmt is the statement to set variables. @@ -322,20 +322,20 @@ type SetStmt struct { } // Accept implements Node Accept interface. -func (nod *SetStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SetStmt) - for i, val := range nod.Variables { + n = newNod.(*SetStmt) + for i, val := range n.Variables { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Variables[i] = node.(*VariableAssignment) + n.Variables[i] = node.(*VariableAssignment) } - return v.Leave(nod) + return v.Leave(n) } // SetCharsetStmt is a statement to assign values to character and collation variables. @@ -348,13 +348,13 @@ type SetCharsetStmt struct { } // Accept implements Node Accept interface. -func (nod *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SetCharsetStmt) - return v.Leave(nod) + n = newNod.(*SetCharsetStmt) + return v.Leave(n) } // SetPwdStmt is a statement to assign a password to user account. @@ -367,13 +367,13 @@ type SetPwdStmt struct { } // Accept implements Node Accept interface. -func (nod *SetPwdStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*SetPwdStmt) - return v.Leave(nod) + n = newNod.(*SetPwdStmt) + return v.Leave(n) } // UserSpec is used for parsing create user statement. @@ -392,13 +392,13 @@ type CreateUserStmt struct { } // Accept implements Node Accept interface. -func (nod *CreateUserStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*CreateUserStmt) - return v.Leave(nod) + n = newNod.(*CreateUserStmt) + return v.Leave(n) } // DoStmt is the struct for DO statement. @@ -409,20 +409,20 @@ type DoStmt struct { } // Accept implements Node Accept interface. -func (nod *DoStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *DoStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*DoStmt) - for i, val := range nod.Exprs { + n = newNod.(*DoStmt) + for i, val := range n.Exprs { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Exprs[i] = node.(ExprNode) + n.Exprs[i] = node.(ExprNode) } - return v.Leave(nod) + return v.Leave(n) } // PrivElem is the privilege type and optional column list. @@ -433,20 +433,20 @@ type PrivElem struct { } // Accept implements Node Accept interface. -func (nod *PrivElem) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *PrivElem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*PrivElem) - for i, val := range nod.Cols { + n = newNod.(*PrivElem) + for i, val := range n.Cols { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Cols[i] = node.(*ColumnName) + n.Cols[i] = node.(*ColumnName) } - return v.Leave(nod) + return v.Leave(n) } // ObjectTypeType is the type for object type. @@ -491,18 +491,18 @@ type GrantStmt struct { } // Accept implements Node Accept interface. -func (nod *GrantStmt) Accept(v Visitor) (Node, bool) { - newNod, skipChildren := v.Enter(nod) +func (n *GrantStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) if skipChildren { return v.Leave(newNod) } - nod = newNod.(*GrantStmt) - for i, val := range nod.Privs { + n = newNod.(*GrantStmt) + for i, val := range n.Privs { node, ok := val.Accept(v) if !ok { - return nod, false + return n, false } - nod.Privs[i] = node.(*PrivElem) + n.Privs[i] = node.(*PrivElem) } - return v.Leave(nod) + return v.Leave(n) } From b24f320d559474c5cd37fcc1ddf5d12a1c0703c7 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 5 Nov 2015 16:03:26 +0800 Subject: [PATCH 26/27] parser: remove old parser and move ast/parser to /parser. --- Makefile | 22 +- ast/parser/parser.y | 4283 ---------------------------------- ast/parser/parser_test.go | 647 ----- ast/parser/scanner.l | 1138 --------- ast/parser/yy_parser.go | 19 - ddl/ddl_test.go | 2 +- optimizer/infobinder_test.go | 2 +- parser/parser.y | 2147 +++++++---------- parser/parser_test.go | 56 +- parser/scanner.l | 49 +- parser/scanner_test.go | 92 - tidb.go | 2 +- 12 files changed, 951 insertions(+), 7508 deletions(-) delete mode 100644 ast/parser/parser.y delete mode 100644 ast/parser/parser_test.go delete mode 100644 ast/parser/scanner.l delete mode 100644 ast/parser/yy_parser.go delete mode 100644 parser/scanner_test.go diff --git a/Makefile b/Makefile index 871e02ac7b..3468b1758c 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ TARGET = "" .PHONY: godep deps all build install parser clean todo test gotest interpreter server -all: godep parser ast-parser build test check +all: godep parser build test check godep: go get github.com/tools/godep @@ -31,7 +31,7 @@ parser: goyacc -o /dev/null -xegen $$a parser/parser.y; \ goyacc -o parser/parser.go -xe $$a parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \ rm -f $$a; \ - rm -f y.output + rm -f y.output @if [ $(ARCH) = $(LINUX) ]; \ then \ @@ -44,24 +44,6 @@ parser: golex -o parser/scanner.go parser/scanner.l -ast-parser: - a=`mktemp temp.XXXXXX`; \ - goyacc -o /dev/null -xegen $$a ast/parser/parser.y; \ - goyacc -o ast/parser/parser.go -xe $$a ast/parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \ - rm -f $$a; \ - rm -f y.output - - @if [ $(ARCH) = $(LINUX) ]; \ - then \ - sed -i -e 's|//line.*||' -e 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \ - elif [ $(ARCH) = $(MAC) ]; \ - then \ - /usr/bin/sed -i "" 's|//line.*||' ast/parser/parser.go; \ - /usr/bin/sed -i "" 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \ - fi - - golex -o ast/parser/scanner.go ast/parser/scanner.l - check: go get github.com/golang/lint/golint diff --git a/ast/parser/parser.y b/ast/parser/parser.y deleted file mode 100644 index 1393a7dff1..0000000000 --- a/ast/parser/parser.y +++ /dev/null @@ -1,4283 +0,0 @@ -%{ -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -// Inital yacc source generated by ebnf2y[1] -// at 2013-10-04 23:10:47.861401015 +0200 CEST -// -// $ ebnf2y -o ql.y -oe ql.ebnf -start StatementList -pkg ql -p _ -// -// [1]: http://github.com/cznic/ebnf2y - -package parser - -import ( - "strings" - - "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/field" - "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/parser/opcode" - "github.com/pingcap/tidb/util/charset" - "github.com/pingcap/tidb/util/types" -) - -%} - -%union { - offset int // offset - line int - col int - item interface{} - list []interface{} -} - -%token - - /*yy:token "1.%d" */ floatLit "floating-point literal" - /*yy:token "%c" */ identifier "identifier" - /*yy:token "%d" */ intLit "integer literal" - /*yy:token "\"%c\"" */ stringLit "string literal" - /*yy:token "%x" */ hexLit "hexadecimal literal" - /*yy:token "%b" */ bitLit "bit literal" - - - abs "ABS" - add "ADD" - after "AFTER" - all "ALL" - alter "ALTER" - and "AND" - andand "&&" - andnot "&^" - any "ANY" - as "AS" - asc "ASC" - at "AT" - autoIncrement "AUTO_INCREMENT" - avg "AVG" - avgRowLength "AVG_ROW_LENGTH" - begin "BEGIN" - between "BETWEEN" - both "BOTH" - by "BY" - byteType "BYTE" - caseKwd "CASE" - cast "CAST" - character "CHARACTER" - charsetKwd "CHARSET" - check "CHECK" - checksum "CHECKSUM" - coalesce "COALESCE" - collate "COLLATE" - collation "COLLATION" - column "COLUMN" - columns "COLUMNS" - comment "COMMENT" - commit "COMMIT" - compression "COMPRESSION" - concat "CONCAT" - concatWs "CONCAT_WS" - connection "CONNECTION" - constraint "CONSTRAINT" - convert "CONVERT" - count "COUNT" - create "CREATE" - cross "CROSS" - curDate "CURDATE" - currentDate "CURRENT_DATE" - currentUser "CURRENT_USER" - database "DATABASE" - databases "DATABASES" - dateAdd "DATE_ADD" - dateSub "DATE_SUB" - day "DAY" - dayofmonth "DAYOFMONTH" - dayofweek "DAYOFWEEK" - dayofyear "DAYOFYEAR" - deallocate "DEALLOCATE" - defaultKwd "DEFAULT" - delayed "DELAYED" - deleteKwd "DELETE" - desc "DESC" - describe "DESCRIBE" - distinct "DISTINCT" - div "DIV" - do "DO" - drop "DROP" - dual "DUAL" - duplicate "DUPLICATE" - elseKwd "ELSE" - end "END" - engine "ENGINE" - engines "ENGINES" - enum "ENUM" - eq "=" - escape "ESCAPE" - execute "EXECUTE" - exists "EXISTS" - explain "EXPLAIN" - extract "EXTRACT" - falseKwd "false" - first "FIRST" - foreign "FOREIGN" - forKwd "FOR" - foundRows "FOUND_ROWS" - from "FROM" - full "FULL" - fulltext "FULLTEXT" - ge ">=" - global "GLOBAL" - grant "GRANT" - grants "GRANTS" - group "GROUP" - groupConcat "GROUP_CONCAT" - having "HAVING" - highPriority "HIGH_PRIORITY" - hour "HOUR" - identified "IDENTIFIED" - ignore "IGNORE" - ifKwd "IF" - ifNull "IFNULL" - in "IN" - index "INDEX" - inner "INNER" - insert "INSERT" - interval "INTERVAL" - into "INTO" - is "IS" - join "JOIN" - key "KEY" - keyBlockSize "KEY_BLOCK_SIZE" - le "<=" - leading "LEADING" - left "LEFT" - length "LENGTH" - like "LIKE" - limit "LIMIT" - local "LOCAL" - locate "LOCATE" - lock "LOCK" - lower "LOWER" - lowPriority "LOW_PRIORITY" - lsh "<<" - max "MAX" - maxRows "MAX_ROWS" - microsecond "MICROSECOND" - min "MIN" - minute "MINUTE" - minRows "MIN_ROWS" - mod "MOD" - mode "MODE" - month "MONTH" - names "NAMES" - national "NATIONAL" - neq "!=" - neqSynonym "<>" - not "NOT" - null "NULL" - nulleq "<=>" - nullIf "NULLIF" - offset "OFFSET" - on "ON" - option "OPTION" - or "OR" - order "ORDER" - oror "||" - outer "OUTER" - password "PASSWORD" - placeholder "PLACEHOLDER" - prepare "PREPARE" - primary "PRIMARY" - quarter "QUARTER" - quick "QUICK" - rand "RAND" - references "REFERENCES" - regexp "REGEXP" - repeat "REPEAT" - replace "REPLACE" - right "RIGHT" - rlike "RLIKE" - rollback "ROLLBACK" - row "ROW" - rsh ">>" - schema "SCHEMA" - schemas "SCHEMAS" - second "SECOND" - selectKwd "SELECT" - session "SESSION" - set "SET" - share "SHARE" - show "SHOW" - signed "SIGNED" - some "SOME" - start "START" - stringType "string" - substring "SUBSTRING" - substringIndex "SUBSTRING_INDEX" - sum "SUM" - sysVar "SYS_VAR" - sysDate "SYSDATE" - tableKwd "TABLE" - tables "TABLES" - then "THEN" - to "TO" - trailing "TRAILING" - transaction "TRANSACTION" - trim "TRIM" - trueKwd "true" - truncate "TRUNCATE" - underscoreCS "UNDERSCORE_CHARSET" - unknown "UNKNOWN" - union "UNION" - unique "UNIQUE" - unsigned "UNSIGNED" - update "UPDATE" - upper "UPPER" - use "USE" - user "USER" - using "USING" - userVar "USER_VAR" - value "VALUE" - values "VALUES" - variables "VARIABLES" - warnings "WARNINGS" - week "WEEK" - weekday "WEEKDAY" - weekofyear "WEEKOFYEAR" - when "WHEN" - where "WHERE" - xor "XOR" - yearweek "YEARWEEK" - zerofill "ZEROFILL" - - calcFoundRows "SQL_CALC_FOUND_ROWS" - - currentTs "CURRENT_TIMESTAMP" - localTime "LOCALTIME" - localTs "LOCALTIMESTAMP" - now "NOW" - - tinyIntType "TINYINT" - smallIntType "SMALLINT" - mediumIntType "MEDIUMINT" - intType "INT" - integerType "INTEGER" - bigIntType "BIGINT" - bitType "BIT" - - decimalType "DECIMAL" - numericType "NUMERIC" - floatType "float" - doubleType "DOUBLE" - precisionType "PRECISION" - realType "REAL" - - dateType "DATE" - timeType "TIME" - datetimeType "DATETIME" - timestampType "TIMESTAMP" - yearType "YEAR" - - charType "CHAR" - varcharType "VARCHAR" - binaryType "BINARY" - varbinaryType "VARBINARY" - tinyblobType "TINYBLOB" - blobType "BLOB" - mediumblobType "MEDIUMBLOB" - longblobType "LONGBLOB" - tinytextType "TINYTEXT" - textType "TEXT" - mediumtextType "MEDIUMTEXT" - longtextType "LONGTEXT" - - int16Type "int16" - int24Type "int24" - int32Type "int32" - int64Type "int64" - int8Type "int8" - uintType "uint" - uint16Type "uint16" - uint32Type "uint32" - uint64Type "uint64" - uint8Type "uint8", - float32Type "float32" - float64Type "float64" - boolType "BOOL" - booleanType "BOOLEAN" - - parseExpression "parse expression prefix" - - secondMicrosecond "SECOND_MICROSECOND" - minuteMicrosecond "MINUTE_MICROSECOND" - minuteSecond "MINUTE_SECOND" - hourMicrosecond "HOUR_MICROSECOND" - hourSecond "HOUR_SECOND" - hourMinute "HOUR_MINUTE" - dayMicrosecond "DAY_MICROSECOND" - daySecond "DAY_SECOND" - dayMinute "DAY_MINUTE" - dayHour "DAY_HOUR" - yearMonth "YEAR_MONTH" - -%type - AlterTableStmt "Alter table statement" - AlterTableSpec "Alter table specification" - AlterTableSpecList "Alter table specification list" - AnyOrAll "Any or All for subquery" - Assignment "assignment" - AssignmentList "assignment list" - AssignmentListOpt "assignment list opt" - AuthOption "User auth option" - AuthString "Password string value" - BeginTransactionStmt "BEGIN TRANSACTION statement" - CastType "Cast function target type" - ColumnDef "table column definition" - ColumnName "column name" - ColumnNameList "column name list" - ColumnNameListOpt "column name list opt" - ColumnKeywordOpt "Column keyword or empty" - ColumnSetValue "insert statement set value by column name" - ColumnSetValueList "insert statement set value by column name list" - CommaOpt "optional comma" - CommitStmt "COMMIT statement" - CompareOp "Compare opcode" - ColumnOption "column definition option" - ColumnOptionList "column definition option list" - ColumnOptionListOpt "optional column definition option list" - Constraint "table constraint" - ConstraintElem "table constraint element" - ConstraintKeywordOpt "Constraint Keyword or empty" - CreateDatabaseStmt "Create Database Statement" - CreateIndexStmt "CREATE INDEX statement" - CreateIndexStmtUnique "CREATE INDEX optional UNIQUE clause" - DatabaseOption "CREATE Database specification" - DatabaseOptionList "CREATE Database specification list" - DatabaseOptionListOpt "CREATE Database specification list opt" - CreateTableStmt "CREATE TABLE statement" - CreateUserStmt "CREATE User statement" - CrossOpt "Cross join option" - DatabaseSym "DATABASE or SCHEMA" - DBName "Database Name" - DeallocateSym "Deallocate or drop" - DeallocateStmt "Deallocate prepared statement" - Default "DEFAULT clause" - DefaultOpt "optional DEFAULT clause" - DefaultKwdOpt "optional DEFAULT keyword" - DefaultValueExpr "DefaultValueExpr(Now or Signed Literal)" - DeleteFromStmt "DELETE FROM statement" - DistinctOpt "Distinct option" - DoStmt "Do statement" - DropDatabaseStmt "DROP DATABASE statement" - DropIndexStmt "DROP INDEX statement" - DropTableStmt "DROP TABLE statement" - EmptyStmt "empty statement" - EqOpt "= or empty" - EscapedTableRef "escaped table reference" - ExecuteStmt "Execute statement" - ExplainSym "EXPLAIN or DESCRIBE or DESC" - ExplainStmt "EXPLAIN statement" - Expression "expression" - ExpressionList "expression list" - ExpressionListOpt "expression list opt" - ExpressionListList "expression list list" - Factor "expression factor" - PredicateExpr "Predicate expression factor" - Field "field expression" - FieldAsName "Field alias name" - FieldAsNameOpt "Field alias name opt" - FieldList "field expression list" - TableRefsClause "Table references clause" - Function "function expr" - FunctionCallAgg "Function call on aggregate data" - FunctionCallConflict "Function call with reserved keyword as function name" - FunctionCallKeyword "Function call with keyword as function name" - FunctionCallNonKeyword "Function call with nonkeyword as function name" - FunctionNameConflict "Built-in function call names which are conflict with keywords" - FuncDatetimePrec "Function datetime precision" - GlobalScope "The scope of variable" - GrantStmt "Grant statement" - GroupByClause "GROUP BY clause" - HashString "Hashed string" - HavingClause "HAVING clause" - IfExists "If Exists" - IfNotExists "If Not Exists" - IgnoreOptional "IGNORE or empty" - IndexColName "Index column name" - IndexColNameList "List of index column name" - IndexName "index name" - IndexType "index type" - InsertIntoStmt "INSERT INTO statement" - InsertValues "Rest part of INSERT/REPLACE INTO statement" - IntoOpt "INTO or EmptyString" - JoinTable "join table" - JoinType "join type" - KeyOrIndex "{KEY|INDEX}" - LikeEscapeOpt "like escape option" - LimitClause "LIMIT clause" - Literal "literal value" - logAnd "logical and operator" - logOr "logical or operator" - LowPriorityOptional "LOW_PRIORITY or empty" - name "name" - NationalOpt "National option" - NotOpt "optional NOT" - NowSym "CURRENT_TIMESTAMP/LOCALTIME/LOCALTIMESTAMP/NOW" - NumLiteral "Num/Int/Float/Decimal Literal" - ObjectType "Grant statement object type" - OnDuplicateKeyUpdate "ON DUPLICATE KEY UPDATE value list" - Operand "operand" - OptFull "Full or empty" - OptInteger "Optional Integer keyword" - Order "ORDER BY clause optional collation specification" - OrderBy "ORDER BY clause" - ByItem "BY item" - OrderByOptional "Optional ORDER BY clause optional" - ByList "BY list" - OuterOpt "optional OUTER clause" - QuickOptional "QUICK or empty" - PasswordOpt "Password option" - ColumnPosition "Column position [First|After ColumnName]" - PreparedStmt "PreparedStmt" - PrepareSQL "Prepare statement sql string" - PrimaryExpression "primary expression" - PrimaryFactor "primary expression factor" - Priority "insert statement priority" - PrivElem "Privilege element" - PrivElemList "Privilege element list" - PrivLevel "Privilege scope" - PrivType "Privilege type" - ReferDef "Reference definition" - RegexpSym "REGEXP or RLIKE" - ReplaceIntoStmt "REPLACE INTO statement" - ReplacePriority "replace statement priority" - RollbackStmt "ROLLBACK statement" - SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," - SelectStmt "SELECT statement" - SelectStmtCalcFoundRows "SELECT statement optional SQL_CALC_FOUND_ROWS" - SelectStmtDistinct "SELECT statement optional DISTINCT clause" - SelectStmtFieldList "SELECT statement field list" - SelectStmtLimit "SELECT statement optional LIMIT clause" - SelectStmtOpts "Select statement options" - SelectStmtGroup "SELECT statement optional GROUP BY clause" - SetStmt "Set variable statement" - ShowStmt "Show engines/databases/tables/columns/warnings statement" - ShowDatabaseNameOpt "Show tables/columns statement database name option" - ShowTableAliasOpt "Show table alias option" - ShowLikeOrWhereOpt "Show like or where condition option" - SignedLiteral "Literal or NumLiteral with sign" - Statement "statement" - StatementList "statement list" - StringName "string literal or identifier" - StringList "string list" - ExplainableStmt "explainable statement" - SubSelect "Sub Select" - Symbol "Constraint Symbol" - SystemVariable "System defined variable name" - TableAsName "table alias name" - TableAsNameOpt "table alias name optional" - TableElement "table definition element" - TableElementList "table definition element list" - TableFactor "table factor" - TableName "Table name" - TableNameList "Table name list" - TableOption "create table option" - TableOptionList "create table option list" - TableOptionListOpt "create table option list opt" - TableRef "table reference" - TableRefs "table references" - TimeUnit "Time unit" - TrimDirection "Trim string direction" - TruncateTableStmt "TRANSACTION TABLE statement" - UnionOpt "Union Option(empty/ALL/DISTINCT)" - UnionStmt "Union select state ment" - UnionClauseList "Union select clause list" - UnionSelect "Union (select) item" - UpdateStmt "UPDATE statement" - Username "Username" - UserSpec "Username and auth option" - UserSpecList "Username and auth option list" - UserVariable "User defined variable name" - UserVariableList "User defined variable name list" - UseStmt "USE statement" - ValueSym "Value or Values" - VariableAssignment "set variable value" - VariableAssignmentList "set variable value list" - Variable "User or system variable" - WhereClause "WHERE clause" - WhereClauseOptional "Optinal WHERE clause" - - Identifier "identifier or unreserved keyword" - UnReservedKeyword "MySQL unreserved keywords" - NotKeywordToken "Tokens not mysql keyword but treated specially" - - WhenClause "When clause" - WhenClauseList "When clause list" - ElseOpt "Optional else clause" - ExpressionOpt "Optional expression" - - Type "Types" - - NumericType "Numeric types" - IntegerType "Integer Types types" - FixedPointType "Exact value types" - FloatingPointType "Approximate value types" - BitValueType "bit value types" - - StringType "String types" - BlobType "Blob types" - TextType "Text types" - - DateAndTimeType "Date and Time types" - - OptFieldLen "Field length or empty" - FieldLen "Field length" - FieldOpts "Field type definition option list" - FieldOpt "Field type definition option" - FloatOpt "Floating-point type option" - Precision "Floating-point precision option" - OptBinary "Optional BINARY" - CharsetKw "charset or charater set" - OptCharset "Optional Character setting" - OptCollate "Optional Collate setting" - NUM "numbers" - LengthNum "Field length num(uint64)" - -%token tableRefPriority - -%precedence lowerThanCalcFoundRows -%precedence calcFoundRows - -%precedence lowerThanSetKeyword -%precedence set - -%precedence lowerThanInsertValues -%precedence insertValues - -%left join inner cross left right full -/* A dummy token to force the priority of TableRef production in a join. */ -%left tableRefPriority -%precedence on -%left oror or -%left xor -%left andand and -%left between -%precedence lowerThanEq -%left eq ge le neq neqSynonym '>' '<' is like in -%left '|' -%left '&' -%left rsh lsh -%left '-' '+' -%left '*' '/' '%' div mod -%left '^' -%left '~' neg -%right not -%right collate - -%precedence lowerThanLeftParen -%precedence '(' -%precedence lowerThanQuick -%precedence quick -%precedence lowerThanEscape -%precedence escape -%precedence lowerThanComma -%precedence ',' - -%start Start - -%% - -Start: - StatementList -| parseExpression Expression - { - yylex.(*lexer).expr = $2.(ast.ExprNode) - } - -/**************************************AlterTableStmt*************************************** - * See: https://dev.mysql.com/doc/refman/5.7/en/alter-table.html - *******************************************************************************************/ -AlterTableStmt: - "ALTER" IgnoreOptional "TABLE" TableName AlterTableSpecList - { - $$ = &ast.AlterTableStmt{ - Table: $4.(*ast.TableName), - Specs: $5.([]*ast.AlterTableSpec), - } - } - -AlterTableSpec: - TableOptionListOpt - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableOption, - Options:$1.([]*ast.TableOption), - } - } -| "ADD" ColumnKeywordOpt ColumnDef ColumnPosition - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableAddColumn, - Column: $3.(*ast.ColumnDef), - Position: $4.(*ast.ColumnPosition), - } - } -| "ADD" Constraint - { - constraint := $2.(*ast.Constraint) - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableAddConstraint, - Constraint: constraint, - } - } -| "DROP" ColumnKeywordOpt ColumnName - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableDropColumn, - DropColumn: $3.(*ast.ColumnName), - } - } -| "DROP" "PRIMARY" "KEY" - { - $$ = &ast.AlterTableSpec{Tp: ast.AlterTableDropPrimaryKey} - } -| "DROP" KeyOrIndex IndexName - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableDropIndex, - Name: $3.(string), - } - } -| "DROP" "FOREIGN" "KEY" Symbol - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableDropForeignKey, - Name: $4.(string), - } - } - -KeyOrIndex: - "KEY"|"INDEX" - -ColumnKeywordOpt: - {} -| "COLUMN" - -ColumnPosition: - { - $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionNone} - } -| "FIRST" - { - $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionFirst} - } -| "AFTER" ColumnName - { - $$ = &ast.ColumnPosition{ - Tp: ast.ColumnPositionAfter, - RelativeColumn: $2.(*ast.ColumnName), - } - } - -AlterTableSpecList: - AlterTableSpec - { - $$ = []*ast.AlterTableSpec{$1.(*ast.AlterTableSpec)} - } -| AlterTableSpecList ',' AlterTableSpec - { - $$ = append($1.([]*ast.AlterTableSpec), $3.(*ast.AlterTableSpec)) - } - -ConstraintKeywordOpt: - { - $$ = nil - } -| "CONSTRAINT" - { - $$ = nil - } -| "CONSTRAINT" Symbol - { - $$ = $2.(string) - } - -Symbol: - Identifier - -/*******************************************************************************************/ -Assignment: - ColumnName eq Expression - { - $$ = &ast.Assignment{Column: $1.(*ast.ColumnName), Expr:$3.(ast.ExprNode)} - } - -AssignmentList: - Assignment - { - $$ = []*ast.Assignment{$1.(*ast.Assignment)} - } -| AssignmentList ',' Assignment - { - $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) - } - -AssignmentListOpt: - /* EMPTY */ - { - $$ = []*ast.Assignment{} - } -| AssignmentList - -BeginTransactionStmt: - "BEGIN" - { - $$ = &ast.BeginStmt{} - } -| "START" "TRANSACTION" - { - $$ = &ast.BeginStmt{} - } - -ColumnDef: - ColumnName Type ColumnOptionListOpt - { - $$ = &ast.ColumnDef{Name: $1.(*ast.ColumnName), Tp: $2.(*types.FieldType), Options: $3.([]*ast.ColumnOption)} - } - -ColumnName: - Identifier - { - $$ = &ast.ColumnName{Name: model.NewCIStr($1.(string))} - } -| Identifier '.' Identifier - { - $$ = &ast.ColumnName{Table: model.NewCIStr($1.(string)), Name: model.NewCIStr($3.(string))} - } -| Identifier '.' Identifier '.' Identifier - { - $$ = &ast.ColumnName{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string)), Name: model.NewCIStr($5.(string))} - } - -ColumnNameList: - ColumnName - { - $$ = []*ast.ColumnName{$1.(*ast.ColumnName)} - } -| ColumnNameList ',' ColumnName - { - $$ = append($1.([]*ast.ColumnName), $3.(*ast.ColumnName)) - } - -ColumnNameListOpt: - /* EMPTY */ - { - $$ = []*ast.ColumnName{} - } -| ColumnNameList - { - $$ = $1.([]*ast.ColumnName) - } - -CommitStmt: - "COMMIT" - { - $$ = &ast.CommitStmt{} - } - -ColumnOption: - "NOT" "NULL" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNotNull} - } -| "NULL" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNull} - } -| "AUTO_INCREMENT" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionAutoIncrement} - } -| "PRIMARY" "KEY" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionPrimaryKey} - } -| "UNIQUE" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniq} - } -| "UNIQUE" "KEY" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniqKey} - } -| "DEFAULT" DefaultValueExpr - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionDefaultValue, Expr: $2.(ast.ExprNode)} - } -| "ON" "UPDATE" NowSym - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionOnUpdate, Expr: $3.(ast.ExprNode)} - } -| "COMMENT" stringLit - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionComment} - } -| "CHECK" '(' Expression ')' - { - // See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html - // The CHECK clause is parsed but ignored by all storage engines. - $$ = &ast.ColumnOption{} - } - -ColumnOptionList: - ColumnOption - { - $$ = []*ast.ColumnOption{$1.(*ast.ColumnOption)} - } -| ColumnOptionList ColumnOption - { - $$ = append($1.([]*ast.ColumnOption), $2.(*ast.ColumnOption)) - } - -ColumnOptionListOpt: - { - $$ = []*ast.ColumnOption{} - } -| ColumnOptionList - { - $$ = $1.([]*ast.ColumnOption) - } - -ConstraintElem: - "PRIMARY" "KEY" '(' IndexColNameList ')' - { - $$ = &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: $4.([]*ast.IndexColName)} - } -| "FULLTEXT" "KEY" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintFulltext, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string), - } - } -| "INDEX" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintIndex, - Keys: $4.([]*ast.IndexColName), - Name: $2.(string), - } - } -| "KEY" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintKey, - Keys: $4.([]*ast.IndexColName), - Name: $2.(string)} - } -| "UNIQUE" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintUniq, - Keys: $4.([]*ast.IndexColName), - Name: $2.(string)} - } -| "UNIQUE" "INDEX" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintUniqIndex, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string)} - } -| "UNIQUE" "KEY" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintUniqKey, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string)} - } -| "FOREIGN" "KEY" IndexName '(' IndexColNameList ')' ReferDef - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintForeignKey, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string), - Refer: $7.(*ast.ReferenceDef), - } - } - -ReferDef: - "REFERENCES" TableName '(' IndexColNameList ')' - { - $$ = &ast.ReferenceDef{Table: $2.(*ast.TableName), IndexColNames: $4.([]*ast.IndexColName)} - } - -/* - * The DEFAULT clause specifies a default value for a column. - * With one exception, the default value must be a constant; - * it cannot be a function or an expression. This means, for example, - * that you cannot set the default for a date column to be the value of - * a function such as NOW() or CURRENT_DATE. The exception is that you - * can specify CURRENT_TIMESTAMP as the default for a TIMESTAMP or DATETIME column. - * - * See: http://dev.mysql.com/doc/refman/5.7/en/create-table.html - * https://github.com/mysql/mysql-server/blob/5.7/sql/sql_yacc.yy#L6832 - */ -DefaultValueExpr: - NowSym -| SignedLiteral - -// TODO: Process other three keywords -NowSym: - "CURRENT_TIMESTAMP" - { - $$ = &ast.IdentifierExpr{Name: model.NewCIStr("CURRENT_TIMESTAMP")} - } -| "LOCALTIME" -| "LOCALTIMESTAMP" -| "NOW" - -SignedLiteral: - Literal - { - $$ = ast.NewValueExpr($1) - } -| '+' NumLiteral - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: ast.NewValueExpr($2)} - } -| '-' NumLiteral - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr($2)} - } - -// TODO: support decimal literal -NumLiteral: - intLit -| floatLit - - -CreateIndexStmt: - "CREATE" CreateIndexStmtUnique "INDEX" Identifier "ON" TableName '(' IndexColNameList ')' - { - $$ = &ast.CreateIndexStmt{ - Unique: $2.(bool), - IndexName: $4.(string), - Table: $6.(*ast.TableName), - IndexColNames: $8.([]*ast.IndexColName), - } - if yylex.(*lexer).root { - break - } - } - -CreateIndexStmtUnique: - { - $$ = false - } -| "UNIQUE" - { - $$ = true - } - -IndexColName: - ColumnName OptFieldLen Order - { - //Order is parsed but just ignored as MySQL did - $$ = &ast.IndexColName{Column: $1.(*ast.ColumnName), Length: $2.(int)} - } - -IndexColNameList: - { - $$ = []*ast.IndexColName{} - } -| IndexColName - { - $$ = []*ast.IndexColName{$1.(*ast.IndexColName)} - } -| IndexColNameList ',' IndexColName - { - $$ = append($1.([]*ast.IndexColName), $3.(*ast.IndexColName)) - } - - - -/******************************************************************* - * - * Create Database Statement - * CREATE {DATABASE | SCHEMA} [IF NOT EXISTS] db_name - * [create_specification] ... - * - * create_specification: - * [DEFAULT] CHARACTER SET [=] charset_name - * | [DEFAULT] COLLATE [=] collation_name - *******************************************************************/ -CreateDatabaseStmt: - "CREATE" DatabaseSym IfNotExists DBName DatabaseOptionListOpt - { - $$ = &ast.CreateDatabaseStmt{ - IfNotExists: $3.(bool), - Name: $4.(string), - Options: $5.([]*ast.DatabaseOption), - } - - if yylex.(*lexer).root { - break - } - } - -DBName: - Identifier - -DatabaseOption: - DefaultKwdOpt CharsetKw EqOpt StringName - { - $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCharset, Value: $4.(string)} - } -| DefaultKwdOpt "COLLATE" EqOpt StringName - { - $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: $4.(string)} - } - -DatabaseOptionListOpt: - { - $$ = []*ast.DatabaseOption{} - } -| DatabaseOptionList - -DatabaseOptionList: - DatabaseOption - { - $$ = []*ast.DatabaseOption{$1.(*ast.DatabaseOption)} - } -| DatabaseOptionList DatabaseOption - { - $$ = append($1.([]*ast.DatabaseOption), $2.(*ast.DatabaseOption)) - } - -/******************************************************************* - * - * Create Table Statement - * - * Example: - * CREATE TABLE Persons - * ( - * P_Id int NOT NULL, - * LastName varchar(255) NOT NULL, - * FirstName varchar(255), - * Address varchar(255), - * City varchar(255), - * PRIMARY KEY (P_Id) - * ) - *******************************************************************/ -CreateTableStmt: - "CREATE" "TABLE" IfNotExists TableName '(' TableElementList ')' TableOptionListOpt - { - tes := $6.([]interface {}) - var columnDefs []*ast.ColumnDef - var constraints []*ast.Constraint - for _, te := range tes { - switch te := te.(type) { - case *ast.ColumnDef: - columnDefs = append(columnDefs, te) - case *ast.Constraint: - constraints = append(constraints, te) - } - } - if len(columnDefs) == 0 { - yylex.(*lexer).err("Column Definition List can't be empty.") - return 1 - } - $$ = &ast.CreateTableStmt{ - Table: $4.(*ast.TableName), - IfNotExists: $3.(bool), - Cols: columnDefs, - Constraints: constraints, - Options: $8.([]*ast.TableOption), - } - } - -Default: - "DEFAULT" Expression - { - $$ = $2 - } - -DefaultOpt: - { - $$ = nil - } -| Default - -DefaultKwdOpt: - {} -| "DEFAULT" - -/****************************************************************** - * Do statement - * See: https://dev.mysql.com/doc/refman/5.7/en/do.html - ******************************************************************/ -DoStmt: - "DO" ExpressionList - { - $$ = &ast.DoStmt { - Exprs: $2.([]ast.ExprNode), - } - } - -/******************************************************************* - * - * Delete Statement - * - *******************************************************************/ -DeleteFromStmt: - "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableName WhereClauseOptional OrderByOptional LimitClause - { - // Single Table - join := &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil} - x := &ast.DeleteStmt{ - TableRefs: &ast.TableRefsClause{TableRefs: join}, - LowPriority: $2.(bool), - Quick: $3.(bool), - Ignore: $4.(bool), - } - if $7 != nil { - x.Where = $7.(ast.ExprNode) - } - if $8 != nil { - x.Order = $8.(*ast.OrderByClause) - } - if $9 != nil { - x.Limit = $9.(*ast.Limit) - } - - $$ = x - if yylex.(*lexer).root { - break - } - } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional TableNameList "FROM" TableRefs WhereClauseOptional - { - // Multiple Table - x := &ast.DeleteStmt{ - LowPriority: $2.(bool), - Quick: $3.(bool), - Ignore: $4.(bool), - MultiTable: true, - BeforeFrom: true, - Tables: $5.([]*ast.TableName), - TableRefs: &ast.TableRefsClause{TableRefs: $7.(*ast.Join)}, - } - if $8 != nil { - x.Where = $8.(ast.ExprNode) - } - $$ = x - if yylex.(*lexer).root { - break - } - } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableNameList "USING" TableRefs WhereClauseOptional - { - // Multiple Table - x := &ast.DeleteStmt{ - LowPriority: $2.(bool), - Quick: $3.(bool), - Ignore: $4.(bool), - MultiTable: true, - Tables: $6.([]*ast.TableName), - TableRefs: &ast.TableRefsClause{TableRefs: $8.(*ast.Join)}, - } - if $9 != nil { - x.Where = $9.(ast.ExprNode) - } - $$ = x - if yylex.(*lexer).root { - break - } - } - -DatabaseSym: - "DATABASE" | "SCHEMA" - -DropDatabaseStmt: - "DROP" DatabaseSym IfExists DBName - { - $$ = &ast.DropDatabaseStmt{IfExists: $3.(bool), Name: $4.(string)} - if yylex.(*lexer).root { - break - } - } - -DropIndexStmt: - "DROP" "INDEX" IfExists Identifier "ON" TableName - { - $$ = &ast.DropIndexStmt{IfExists: $3.(bool), IndexName: $4.(string), Table: $6.(*ast.TableName)} - } - -DropTableStmt: - "DROP" TableOrTables TableNameList - { - $$ = &ast.DropTableStmt{Tables: $3.([]*ast.TableName)} - if yylex.(*lexer).root { - break - } - } -| "DROP" TableOrTables "IF" "EXISTS" TableNameList - { - $$ = &ast.DropTableStmt{IfExists: true, Tables: $5.([]*ast.TableName)} - if yylex.(*lexer).root { - break - } - } - -TableOrTables: - "TABLE" -| "TABLES" - -EqOpt: - { - } -| eq - { - } - -EmptyStmt: - /* EMPTY */ - { - $$ = nil - } - -ExplainSym: - "EXPLAIN" -| "DESCRIBE" -| "DESC" - -ExplainStmt: - ExplainSym TableName - { - $$ = &ast.ExplainStmt{ - Stmt: &ast.ShowStmt{ - Tp: ast.ShowColumns, - Table: $2.(*ast.TableName), - }, - } - } -| ExplainSym TableName ColumnName - { - $$ = &ast.ExplainStmt{ - Stmt: &ast.ShowStmt{ - Tp: ast.ShowColumns, - Table: $2.(*ast.TableName), - Column: $3.(*ast.ColumnName), - }, - } - } -| ExplainSym ExplainableStmt - { - $$ = &ast.ExplainStmt{Stmt: $2.(ast.StmtNode)} - } - -LengthNum: - NUM - { - switch v := $1.(type) { - case int64: - $$ = uint64(v) - case uint64: - $$ = uint64(v) - } - } - -NUM: - intLit - -Expression: - Expression logOr Expression %prec oror - { - $$ = &ast.BinaryOperationExpr{Op: opcode.OrOr, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| Expression "XOR" Expression %prec xor - { - $$ = &ast.BinaryOperationExpr{Op: opcode.LogicXor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| Expression logAnd Expression %prec andand - { - $$ = &ast.BinaryOperationExpr{Op: opcode.AndAnd, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| "NOT" Expression %prec not - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} - } -| Factor "IS" NotOpt trueKwd %prec is - { - $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(1)} - } -| Factor "IS" NotOpt falseKwd %prec is - { - $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(0)} - } -| Factor "IS" NotOpt "UNKNOWN" %prec is - { - /* https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#operator_is */ - $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} - } -| Factor - - -logOr: - "||" - { - } -| "OR" - { - } - -logAnd: - "&&" - { - } -| "AND" - { - } - -name: - Identifier - -ExpressionList: - Expression - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| ExpressionList ',' Expression - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) - } - -ExpressionListOpt: - { - $$ = []ast.ExprNode{} - } -| ExpressionList - -Factor: - Factor "IS" NotOpt "NULL" %prec is - { - $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} - } -| Factor CompareOp PredicateExpr %prec eq - { - $$ = &ast.BinaryOperationExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| Factor CompareOp AnyOrAll SubSelect %prec eq - { - $$ = &ast.CompareSubqueryExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $4.(*ast.SubqueryExpr), All: $3.(bool)} - } -| PredicateExpr - -CompareOp: - ">=" - { - $$ = opcode.GE - } -| '>' - { - $$ = opcode.GT - } -| "<=" - { - $$ = opcode.LE - } -| '<' - { - $$ = opcode.LT - } -| "!=" - { - $$ = opcode.NE - } -| "<>" - { - $$ = opcode.NE - } -| "=" - { - $$ = opcode.EQ - } -| "<=>" - { - $$ = opcode.NullEQ - } - -AnyOrAll: - "ANY" - { - $$ = false - } -| "SOME" - { - $$ = false - } -| "ALL" - { - $$ = true - } - -PredicateExpr: - PrimaryFactor NotOpt "IN" '(' ExpressionList ')' - { - $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), List: $5.([]ast.ExprNode)} - } -| PrimaryFactor NotOpt "IN" SubSelect - { - $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), Sel: $4.(*ast.SubqueryExpr)} - } -| PrimaryFactor NotOpt "BETWEEN" PrimaryFactor "AND" PredicateExpr - { - $$ = &ast.BetweenExpr{ - Expr: $1.(ast.ExprNode), - Left: $4.(ast.ExprNode), - Right: $6.(ast.ExprNode), - Not: $2.(bool), - } - } -| PrimaryFactor NotOpt "LIKE" PrimaryExpression LikeEscapeOpt - { - escape := $5.(string) - if len(escape) > 1 { - yylex.(*lexer).errf("Incorrect arguments %s to ESCAPE", escape) - return 1 - } else if len(escape) == 0 { - escape = "\\" - } - $$ = &ast.PatternLikeExpr{ - Expr: $1.(ast.ExprNode), - Pattern: $4.(ast.ExprNode), - Not: $2.(bool), - Escape: escape[0], - } - } -| PrimaryFactor NotOpt RegexpSym PrimaryExpression - { - $$ = &ast.PatternRegexpExpr{Expr: $1.(ast.ExprNode), Pattern: $4.(ast.ExprNode), Not: $2.(bool)} - } -| PrimaryFactor - -RegexpSym: - "REGEXP" -| "RLIKE" - -LikeEscapeOpt: - %prec lowerThanEscape - { - $$ = "\\" - } -| "ESCAPE" stringLit - { - $$ = $2 - } - -NotOpt: - { - $$ = false - } -| "NOT" - { - $$ = true - } - -Field: - '*' - { - $$ = &ast.SelectField{WildCard: &ast.WildCardField{}} - } -| Identifier '.' '*' - { - wildCard := &ast.WildCardField{Table: model.NewCIStr($1.(string))} - $$ = &ast.SelectField{WildCard: wildCard} - } -| Identifier '.' Identifier '.' '*' - { - wildCard := &ast.WildCardField{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string))} - $$ = &ast.SelectField{WildCard: wildCard} - } -| Expression FieldAsNameOpt - { - expr := $1.(ast.ExprNode) - asName := $2.(string) - if asName != "" { - // Set expr original text. - offset := yyS[yypt-1].offset - end := yyS[yypt].offset-1 - expr.SetText(yylex.(*lexer).src[offset:end]) - } - $$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)} - } - -FieldAsNameOpt: - /* EMPTY */ - { - $$ = "" - } -| FieldAsName - { - $$ = $1 - } - -FieldAsName: - Identifier - { - $$ = $1 - } -| "AS" Identifier - { - $$ = $2 - } -| stringLit - { - $$ = $1 - } -| "AS" stringLit - { - $$ = $2 - } - -FieldList: - Field - { - field := $1.(*ast.SelectField) - field.Offset = yyS[yypt].offset - $$ = []*ast.SelectField{field} - } -| FieldList ',' Field - { - - fl := $1.([]*ast.SelectField) - last := fl[len(fl)-1] - if last.Expr != nil && last.AsName.O == "" { - lastEnd := yyS[yypt-1].offset-1 // Comma offset. - last.SetText(yylex.(*lexer).src[last.Offset:lastEnd]) - } - newField := $3.(*ast.SelectField) - newField.Offset = yyS[yypt].offset - $$ = append(fl, newField) - } - -GroupByClause: - "GROUP" "BY" ByList - { - $$ = &ast.GroupByClause{Items: $3.([]*ast.ByItem)} - } - -HavingClause: - { - $$ = nil - } -| "HAVING" Expression - { - $$ = &ast.HavingClause{Expr: $2.(ast.ExprNode)} - } - -IfExists: - { - $$ = false - } -| "IF" "EXISTS" - { - $$ = true - } - -IfNotExists: - { - $$ = false - } -| "IF" "NOT" "EXISTS" - { - $$ = true - } - - -IgnoreOptional: - { - $$ = false - } -| "IGNORE" - { - $$ = true - } - -IndexName: - { - $$ = "" - } -| Identifier - { - //"index name" - $$ = $1.(string) - } - -IndexType: - Identifier - { - // TODO: "index type" - $$ = $1.(string) - } - -/**********************************Identifier********************************************/ -Identifier: - identifier | UnReservedKeyword | NotKeywordToken - -UnReservedKeyword: - "AUTO_INCREMENT" | "AFTER" | "AVG" | "BEGIN" | "BIT" | "BOOL" | "BOOLEAN" | "CHARSET" | "COLUMNS" | "COMMIT" -| "DATE" | "DATETIME" | "DEALLOCATE" | "DO" | "END" | "ENGINE" | "ENGINES" | "EXECUTE" | "FIRST" | "FULL" -| "LOCAL" | "NAMES" | "OFFSET" | "PASSWORD" %prec lowerThanEq | "PREPARE" | "QUICK" | "ROLLBACK" | "SESSION" | "SIGNED" -| "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" -| "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION" -| "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS" -| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" - -NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" -| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" -| "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" - -/************************************************************************************ - * - * Insert Statments - * - * TODO: support PARTITION - **********************************************************************************/ -InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableName InsertValues OnDuplicateKeyUpdate - { - x := $6.(*ast.InsertStmt) - x.Priority = $2.(int) - // Wraps many layers here so that it can be processed the same way as select statement. - ts := &ast.TableSource{Source: $5.(*ast.TableName)} - x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} - if $7 != nil { - x.OnDuplicate = $7.([]*ast.Assignment) - } - $$ = x - if yylex.(*lexer).root { - break - } - } - -IntoOpt: - { - } -| "INTO" - { - } - -InsertValues: - '(' ColumnNameListOpt ')' ValueSym ExpressionListList - { - $$ = &ast.InsertStmt{ - Columns: $2.([]*ast.ColumnName), - Lists: $5.([][]ast.ExprNode), - } - } -| '(' ColumnNameListOpt ')' SelectStmt - { - $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.SelectStmt)} - } -| '(' ColumnNameListOpt ')' UnionStmt - { - $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.UnionStmt)} - } -| ValueSym ExpressionListList %prec insertValues - { - $$ = &ast.InsertStmt{Lists: $2.([][]ast.ExprNode)} - } -| SelectStmt - { - $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} - } -| UnionStmt - { - $$ = &ast.InsertStmt{Select: $1.(*ast.UnionStmt)} - } -| "SET" ColumnSetValueList - { - $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} - } - -ValueSym: - "VALUE" -| "VALUES" - -ExpressionListList: - '(' ')' - { - $$ = [][]ast.ExprNode{[]ast.ExprNode{}} - } -| '(' ')' ',' ExpressionListList - { - $$ = append([][]ast.ExprNode{[]ast.ExprNode{}}, $4.([][]ast.ExprNode)...) - } -| '(' ExpressionList ')' - { - $$ = [][]ast.ExprNode{$2.([]ast.ExprNode)} - } -| '(' ExpressionList ')' ',' ExpressionListList - { - $$ = append([][]ast.ExprNode{$2.([]ast.ExprNode)}, $5.([][]ast.ExprNode)...) - } - -ColumnSetValue: - ColumnName eq Expression - { - $$ = &ast.Assignment{ - Column: $1.(*ast.ColumnName), - Expr: $3.(ast.ExprNode), - } - } - -ColumnSetValueList: - { - $$ = []*ast.Assignment{} - } -| ColumnSetValue - { - $$ = []*ast.Assignment{$1.(*ast.Assignment)} - } -| ColumnSetValueList ',' ColumnSetValue - { - $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) - } - -/* - * ON DUPLICATE KEY UPDATE col_name=expr [, col_name=expr] ... - * See: https://dev.mysql.com/doc/refman/5.7/en/insert-on-duplicate.html - */ -OnDuplicateKeyUpdate: - { - $$ = nil - } -| "ON" "DUPLICATE" "KEY" "UPDATE" AssignmentList - { - $$ = $5 - } - -/***********************************Insert Statements END************************************/ - -/************************************************************************************ - * Replace Statements - * See: https://dev.mysql.com/doc/refman/5.7/en/replace.html - * - * TODO: support PARTITION - **********************************************************************************/ -ReplaceIntoStmt: - "REPLACE" ReplacePriority IntoOpt TableName InsertValues - { - x := $5.(*ast.InsertStmt) - x.Replace = true - x.Priority = $2.(int) - ts := &ast.TableSource{Source: $4.(*ast.TableName)} - x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} - $$ = x - } - -ReplacePriority: - { - $$ = ast.NoPriority - } -| "LOW_PRIORITY" - { - $$ = ast.LowPriority - } -| "DELAYED" - { - $$ = ast.DelayedPriority - } - -/***********************************Replace Statments END************************************/ - -Literal: - "false" - { - $$ = int64(0) - } -| "NULL" -| "true" - { - $$ = int64(1) - } -| floatLit -| intLit -| stringLit - { - tp := types.NewFieldType(mysql.TypeString) - l := yylex.(*lexer) - tp.Charset, tp.Collate = l.GetCharsetInfo() - $$ = &types.DataItem{ - Type: tp, - Data: $1.(string), - } - } -| "UNDERSCORE_CHARSET" stringLit - { - // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html - tp := types.NewFieldType(mysql.TypeString) - tp.Charset = $1.(string) - co, err := charset.GetDefaultCollation(tp.Charset) - if err != nil { - l := yylex.(*lexer) - l.errf("Get collation error for charset: %s", tp.Charset) - return 1 - } - tp.Collate = co - $$ = &types.DataItem{ - Type: tp, - Data: $2.(string), - } - } -| hexLit -| bitLit - -Operand: - Literal - { - $$ = ast.NewValueExpr($1) - } -| ColumnName - { - $$ = &ast.ColumnNameExpr{Name: $1.(*ast.ColumnName)} - } -| '(' Expression ')' - { - $$ = &ast.ParenthesesExpr{Expr: $2.(ast.ExprNode)} - } -| "DEFAULT" %prec lowerThanLeftParen - { - $$ = &ast.DefaultExpr{} - } -| "DEFAULT" '(' ColumnName ')' - { - $$ = &ast.DefaultExpr{Name: $3.(*ast.ColumnName)} - } -| Variable - { - $$ = $1 - } -| "PLACEHOLDER" - { - $$ = &ast.ParamMarkerExpr{ - Offset: yyS[yypt].offset, - } - } -| "ROW" '(' Expression ',' ExpressionList ')' - { - values := append([]ast.ExprNode{$3.(ast.ExprNode)}, $5.([]ast.ExprNode)...) - $$ = &ast.RowExpr{Values: values} - } -| '(' Expression ',' ExpressionList ')' - { - values := append([]ast.ExprNode{$2.(ast.ExprNode)}, $4.([]ast.ExprNode)...) - $$ = &ast.RowExpr{Values: values} - } -| "EXISTS" SubSelect - { - $$ = &ast.ExistsSubqueryExpr{Sel: $2.(*ast.SubqueryExpr)} - } - -OrderBy: - "ORDER" "BY" ByList - { - $$ = &ast.OrderByClause{Items: $3.([]*ast.ByItem)} - } - -ByList: - ByItem - { - $$ = []*ast.ByItem{$1.(*ast.ByItem)} - } -| ByList ',' ByItem - { - $$ = append($1.([]*ast.ByItem), $3.(*ast.ByItem)) - } - -ByItem: - Expression Order - { - $$ = &ast.ByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} - } - -Order: - /* EMPTY */ - { - $$ = false // ASC by default - } -| "ASC" - { - $$ = false - } -| "DESC" - { - $$ = true - } - -OrderByOptional: - { - $$ = nil - } -| OrderBy - { - $$ = $1 - } - -PrimaryExpression: - Operand -| Function -| SubSelect -| '!' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} - } -| '~' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.BitNeg, V: $2.(ast.ExprNode)} - } -| '-' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: $2.(ast.ExprNode)} - } -| '+' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: $2.(ast.ExprNode)} - } -| "BINARY" PrimaryExpression %prec neg - { - // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#operator_binary - x := types.NewFieldType(mysql.TypeString) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = &ast.FuncCastExpr{ - Expr: $2.(ast.ExprNode), - Tp: x, - FunctionType: ast.CastBinaryOperator, - } - } -| PrimaryExpression "COLLATE" StringName %prec neg - { - // TODO: Create a builtin function hold expr and collation. When do evaluation, convert expr result using the collation. - $$ = $1 - } - -Function: - FunctionCallKeyword -| FunctionCallNonKeyword -| FunctionCallConflict -| FunctionCallAgg - -FunctionNameConflict: - "DATABASE" | "SCHEMA" | "IF" | "LEFT" | "REPEAT" | "CURRENT_USER" | "CURRENT_DATE" - -FunctionCallConflict: - FunctionNameConflict '(' ExpressionListOpt ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "CURRENT_USER" - { - // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user - $$ = &ast.FuncCallExpr{FnName: $1.(string)} - } -| "CURRENT_DATE" - { - $$ = &ast.FuncCallExpr{FnName: $1.(string)} - } - -DistinctOpt: - { - $$ = false - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } -| "DISTINCT" "ALL" - { - $$ = true - } - -FunctionCallKeyword: - "CAST" '(' Expression "AS" CastType ')' - { - /* See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_cast */ - $$ = &ast.FuncCastExpr{ - Expr: $3.(ast.ExprNode), - Tp: $5.(*types.FieldType), - FunctionType: ast.CastFunction, - } - } -| "CASE" ExpressionOpt WhenClauseList ElseOpt "END" - { - x := &ast.CaseExpr{WhenClauses: $3.([]*ast.WhenClause)} - if $2 != nil { - x.Value = $2.(ast.ExprNode) - } - if $4 != nil { - x.ElseClause = $4.(ast.ExprNode) - } - $$ = x - } -| "CONVERT" '(' Expression "USING" StringName ')' - { - // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &ast.FuncConvertExpr{ - Expr: $3.(ast.ExprNode), - Charset: $5.(string), - } - } -| "CONVERT" '(' Expression ',' CastType ')' - { - // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &ast.FuncCastExpr{ - Expr: $3.(ast.ExprNode), - Tp: $5.(*types.FieldType), - FunctionType: ast.CastConvertFunction, - } - } -| "DATE" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "USER" '(' ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string)} - } -| "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues - { - // TODO: support qualified identifier for column_name - $$ = &ast.ValuesExpr{Column: $3.(*ast.ColumnName)} - } -| "WEEK" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "YEAR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } - -FunctionCallNonKeyword: - "COALESCE" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "CURDATE" '(' ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string)} - } -| "CURRENT_TIMESTAMP" FuncDatetimePrec - { - args := []ast.ExprNode{} - if $2 != nil { - args = append(args, $2.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} - } -| "ABS" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "CONCAT" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "CONCAT_WS" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "DAY" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "DAYOFWEEK" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "DAYOFMONTH" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "DAYOFYEAR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' - { - $$ = &ast.FuncDateArithExpr{ - Op:ast.DateAdd, - Unit: $7.(string), - Date: $3.(ast.ExprNode), - Interval: $6.(ast.ExprNode), - } - } -| "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' - { - $$ = &ast.FuncDateArithExpr{ - Op:ast.DateSub, - Unit: $7.(string), - Date: $3.(ast.ExprNode), - Interval: $6.(ast.ExprNode), - } - } -| "EXTRACT" '(' TimeUnit "FROM" Expression ')' - { - $$ = &ast.FuncExtractExpr{ - Unit: $3.(string), - Date: $5.(ast.ExprNode), - } - } -| "FOUND_ROWS" '(' ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string)} - } -| "HOUR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "IFNULL" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "LENGTH" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "LOCATE" '(' Expression ',' Expression ')' - { - $$ = &ast.FuncLocateExpr{ - SubStr: $3.(ast.ExprNode), - Str: $5.(ast.ExprNode), - } - } -| "LOCATE" '(' Expression ',' Expression ',' Expression ')' - { - $$ = &ast.FuncLocateExpr{ - SubStr: $3.(ast.ExprNode), - Str: $5.(ast.ExprNode), - Pos: $7.(ast.ExprNode), - } - } -| "LOWER" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "MICROSECOND" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "MINUTE" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "MONTH" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "NOW" '(' ExpressionOpt ')' - { - args := []ast.ExprNode{} - if $3 != nil { - args = append(args, $3.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} - } -| "NULLIF" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "RAND" '(' ExpressionOpt ')' - { - - args := []ast.ExprNode{} - if $3 != nil { - args = append(args, $3.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} - } -| "REPLACE" '(' Expression ',' Expression ',' Expression ')' - { - args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} - } -| "SECOND" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "SUBSTRING" '(' Expression ',' Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - } - } -| "SUBSTRING" '(' Expression "FROM" Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - } - } -| "SUBSTRING" '(' Expression ',' Expression ',' Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - Len: $7.(ast.ExprNode), - } - } -| "SUBSTRING" '(' Expression "FROM" Expression "FOR" Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - Len: $7.(ast.ExprNode), - } - } -| "SUBSTRING_INDEX" '(' Expression ',' Expression ',' Expression ')' - { - $$ = &ast.FuncSubstringIndexExpr{ - StrExpr: $3.(ast.ExprNode), - Delim: $5.(ast.ExprNode), - Count: $7.(ast.ExprNode), - } - } -| "SYSDATE" '(' ExpressionOpt ')' - { - args := []ast.ExprNode{} - if $3 != nil { - args = append(args, $3.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} - } -| "TRIM" '(' Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $3.(ast.ExprNode), - } - } -| "TRIM" '(' Expression "FROM" Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $5.(ast.ExprNode), - RemStr: $3.(ast.ExprNode), - } - } -| "TRIM" '(' TrimDirection "FROM" Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $5.(ast.ExprNode), - Direction: $3.(ast.TrimDirectionType), - } - } -| "TRIM" '(' TrimDirection Expression "FROM" Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $6.(ast.ExprNode), - RemStr: $4.(ast.ExprNode), - Direction: $3.(ast.TrimDirectionType), - } - } -| "UPPER" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "WEEKDAY" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "WEEKOFYEAR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "YEARWEEK" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} - } - -TrimDirection: - "BOTH" - { - $$ = ast.TrimBoth - } -| "LEADING" - { - $$ = ast.TrimLeading - } -| "TRAILING" - { - $$ = ast.TrimTrailing - } - -FunctionCallAgg: - "AVG" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "COUNT" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "COUNT" '(' DistinctOpt '*' ')' - { - args := []ast.ExprNode{ast.NewValueExpr(ast.UnquoteString("*"))} - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} - } -| "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "MAX" '(' DistinctOpt Expression ')' - { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} - } -| "MIN" '(' DistinctOpt Expression ')' - { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} - } -| "SUM" '(' DistinctOpt Expression ')' - { - $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} - } - -FuncDatetimePrec: - { - $$ = nil - } -| '(' ')' - { - $$ = nil - } -| '(' Expression ')' - { - $$ = $2 - } - -TimeUnit: - "MICROSECOND" | "SECOND" | "MINUTE" | "HOUR" | "DAY" | "WEEK" -| "MONTH" | "QUARTER" | "YEAR" | "SECOND_MICROSECOND" | "MINUTE_MICROSECOND" -| "MINUTE_SECOND" | "HOUR_MICROSECOND" | "HOUR_SECOND" | "HOUR_MINUTE" -| "DAY_MICROSECOND" | "DAY_SECOND" | "DAY_MINUTE" | "DAY_HOUR" | "YEAR_MONTH" - -ExpressionOpt: - { - $$ = nil - } -| Expression - { - $$ = $1 - } - -WhenClauseList: - WhenClause - { - $$ = []*ast.WhenClause{$1.(*ast.WhenClause)} - } -| WhenClauseList WhenClause - { - $$ = append($1.([]*ast.WhenClause), $2.(*ast.WhenClause)) - } - -WhenClause: - "WHEN" Expression "THEN" Expression - { - $$ = &ast.WhenClause{ - Expr: $2.(ast.ExprNode), - Result: $4.(ast.ExprNode), - } - } - -ElseOpt: - /* empty */ - { - $$ = nil - } -| "ELSE" Expression - { - $$ = $2 - } - -CastType: - "BINARY" OptFieldLen - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "CHAR" OptFieldLen OptBinary OptCharset - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $2.(int) - if $3.(bool) { - x.Flag |= mysql.BinaryFlag - } - x.Charset = $4.(string) - $$ = x - } -| "DATE" - { - x := types.NewFieldType(mysql.TypeDate) - $$ = x - } -| "DATETIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDatetime) - x.Decimal = $2.(int) - $$ = x - } -| "DECIMAL" FloatOpt - { - fopt := $2.(*ast.FloatOpt) - x := types.NewFieldType(mysql.TypeNewDecimal) - x.Flen = fopt.Flen - x.Decimal = fopt.Decimal - $$ = x - } -| "TIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDuration) - x.Decimal = $2.(int) - $$ = x - } -| "SIGNED" OptInteger - { - x := types.NewFieldType(mysql.TypeLonglong) - $$ = x - } -| "UNSIGNED" OptInteger - { - x := types.NewFieldType(mysql.TypeLonglong) - x.Flag |= mysql.UnsignedFlag - $$ = x - } - - -PrimaryFactor: - PrimaryFactor '|' PrimaryFactor %prec '|' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Or, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '&' PrimaryFactor %prec '&' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.And, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor "<<" PrimaryFactor %prec lsh - { - $$ = &ast.BinaryOperationExpr{Op: opcode.LeftShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor ">>" PrimaryFactor %prec rsh - { - $$ = &ast.BinaryOperationExpr{Op: opcode.RightShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '+' PrimaryFactor %prec '+' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Plus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '-' PrimaryFactor %prec '-' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Minus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '*' PrimaryFactor %prec '*' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Mul, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '/' PrimaryFactor %prec '/' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Div, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '%' PrimaryFactor %prec '%' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor "DIV" PrimaryFactor %prec div - { - $$ = &ast.BinaryOperationExpr{Op: opcode.IntDiv, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor "MOD" PrimaryFactor %prec mod - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '^' PrimaryFactor - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Xor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryExpression - - -Priority: - { - $$ = ast.NoPriority - } -| "LOW_PRIORITY" - { - $$ = ast.LowPriority - } -| "HIGH_PRIORITY" - { - $$ = ast.HighPriority - } -| "DELAYED" - { - $$ = ast.DelayedPriority - } - -LowPriorityOptional: - { - $$ = false - } -| "LOW_PRIORITY" - { - $$ = true - } - -TableName: - Identifier - { - $$ = &ast.TableName{Name:model.NewCIStr($1.(string))} - } -| Identifier '.' Identifier - { - $$ = &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - } - -TableNameList: - TableName - { - tbl := []*ast.TableName{$1.(*ast.TableName)} - $$ = tbl - } -| TableNameList ',' TableName - { - $$ = append($1.([]*ast.TableName), $3.(*ast.TableName)) - } - -QuickOptional: - %prec lowerThanQuick - { - $$ = false - } -| "QUICK" - { - $$ = true - } - -/***************************Prepared Statement Start****************************** - * See: https://dev.mysql.com/doc/refman/5.7/en/prepare.html - * Example: - * PREPARE stmt_name FROM 'SELECT SQRT(POW(?,2) + POW(?,2)) AS hypotenuse'; - * OR - * SET @s = 'SELECT SQRT(POW(?,2) + POW(?,2)) AS hypotenuse'; - * PREPARE stmt_name FROM @s; - */ - -PreparedStmt: - "PREPARE" Identifier "FROM" PrepareSQL - { - var sqlText string - var sqlVar *ast.VariableExpr - switch $4.(type) { - case string: - sqlText = $4.(string) - case *ast.VariableExpr: - sqlVar = $4.(*ast.VariableExpr) - } - $$ = &ast.PrepareStmt{ - InPrepare: true, - Name: $2.(string), - SQLText: sqlText, - SQLVar: sqlVar, - } - } - -PrepareSQL: - stringLit -| UserVariable - - -/* - * See: https://dev.mysql.com/doc/refman/5.7/en/execute.html - * Example: - * EXECUTE stmt1 USING @a, @b; - * OR - * EXECUTE stmt1; - */ -ExecuteStmt: - "EXECUTE" Identifier - { - $$ = &ast.ExecuteStmt{Name: $2.(string)} - } -| "EXECUTE" Identifier "USING" UserVariableList - { - $$ = &ast.ExecuteStmt{ - Name: $2.(string), - UsingVars: $4.([]ast.ExprNode), - } - } - -UserVariableList: - UserVariable - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| UserVariableList ',' UserVariable - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) - } - -/* - * See: https://dev.mysql.com/doc/refman/5.0/en/deallocate-prepare.html - */ - -DeallocateStmt: - DeallocateSym "PREPARE" Identifier - { - $$ = &ast.DeallocateStmt{Name: $3.(string)} - } - -DeallocateSym: - "DEALLOCATE" | "DROP" - -/****************************Prepared Statement End*******************************/ - - -RollbackStmt: - "ROLLBACK" - { - $$ = &ast.RollbackStmt{} - } - -SelectStmt: - "SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt - { - st := &ast.SelectStmt { - Distinct: $2.(bool), - Fields: $3.(*ast.FieldList), - LockTp: $5.(ast.SelectLockType), - } - lastField := st.Fields.Fields[len(st.Fields.Fields)-1] - if lastField.Expr != nil && lastField.AsName.O == "" { - src := yylex.(*lexer).src - var lastEnd int - if $4 != nil { - lastEnd = yyS[yypt-1].offset-1 - } else if $5 != ast.SelectLockNone { - lastEnd = yyS[yypt].offset-1 - } else { - lastEnd = len(src) - if src[lastEnd-1] == ';' { - lastEnd-- - } - } - lastField.SetText(src[lastField.Offset:lastEnd]) - } - if $4 != nil { - st.Limit = $4.(*ast.Limit) - } - $$ = st - } -| "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt - { - st := &ast.SelectStmt { - Distinct: $2.(bool), - Fields: $3.(*ast.FieldList), - LockTp: $7.(ast.SelectLockType), - } - lastField := st.Fields.Fields[len(st.Fields.Fields)-1] - if lastField.Expr != nil && lastField.AsName.O == "" { - lastEnd := yyS[yypt-3].offset-1 - lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) - } - if $5 != nil { - st.Where = $5.(ast.ExprNode) - } - if $6 != nil { - st.Limit = $6.(*ast.Limit) - } - $$ = st - } -| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional - SelectStmtLimit SelectLockOpt - { - st := &ast.SelectStmt{ - Distinct: $2.(bool), - Fields: $3.(*ast.FieldList), - From: $5.(*ast.TableRefsClause), - LockTp: $11.(ast.SelectLockType), - } - - lastField := st.Fields.Fields[len(st.Fields.Fields)-1] - if lastField.Expr != nil && lastField.AsName.O == "" { - lastEnd := yyS[yypt-7].offset-1 - lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) - } - - if $6 != nil { - st.Where = $6.(ast.ExprNode) - } - - if $7 != nil { - st.GroupBy = $7.(*ast.GroupByClause) - } - - if $8 != nil { - st.Having = $8.(*ast.HavingClause) - } - - if $9 != nil { - st.OrderBy = $9.(*ast.OrderByClause) - } - - if $10 != nil { - st.Limit = $10.(*ast.Limit) - } - - $$ = st - } - -FromDual: - "FROM" "DUAL" - - -TableRefsClause: - TableRefs - { - $$ = &ast.TableRefsClause{TableRefs: $1.(*ast.Join)} - } - -TableRefs: - EscapedTableRef - { - if j, ok := $1.(*ast.Join); ok { - // if $1 is Join, use it directly - $$ = j - } else { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: nil} - } - } -| TableRefs ',' EscapedTableRef - { - /* from a, b is default cross join */ - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} - } - -EscapedTableRef: - TableRef %prec lowerThanSetKeyword - { - $$ = $1 - } -| '{' Identifier TableRef '}' - { - /* - * ODBC escape syntax for outer join is { OJ join_table } - * Use an Identifier for OJ - */ - $$ = $3 - } - -TableRef: - TableFactor - { - $$ = $1 - } -| JoinTable - { - $$ = $1 - } - -TableFactor: - TableName TableAsNameOpt - { - $$ = &ast.TableSource{Source: $1.(*ast.TableName), AsName: $2.(model.CIStr)} - } -| '(' SelectStmt ')' TableAsName - { - st := $2.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1) - $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} - } -| '(' UnionStmt ')' TableAsName - { - $$ = &ast.TableSource{Source: $2.(*ast.UnionStmt), AsName: $4.(model.CIStr)} - } -| '(' TableRefs ')' - { - $$ = $2 - } - -TableAsNameOpt: - { - $$ = model.CIStr{} - } -| TableAsName - { - $$ = $1 - } - -TableAsName: - Identifier - { - $$ = model.NewCIStr($1.(string)) - } -| "AS" Identifier - { - $$ = model.NewCIStr($2.(string)) - } - -JoinTable: - /* Use %prec to evaluate production TableRef before cross join */ - TableRef CrossOpt TableRef %prec tableRefPriority - { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} - } -| TableRef CrossOpt TableRef "ON" Expression - { - on := &ast.OnCondition{Expr: $5.(ast.ExprNode)} - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on} - } -| TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression - { - on := &ast.OnCondition{Expr: $7.(ast.ExprNode)} - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on} - } - /* Support Using */ - -JoinType: - "LEFT" - { - $$ = ast.LeftJoin - } -| "RIGHT" - { - $$ = ast.RightJoin - } - -OuterOpt: - { - $$ = nil - } -| "OUTER" - - -CrossOpt: - "JOIN" -| "CROSS" "JOIN" -| "INNER" "JOIN" - - -LimitClause: - { - $$ = nil - } -| "LIMIT" LengthNum - { - $$ = &ast.Limit{Count: $2.(uint64)} - } - -SelectStmtLimit: - { - $$ = nil - } -| "LIMIT" LengthNum - { - $$ = &ast.Limit{Count: $2.(uint64)} - } -| "LIMIT" LengthNum ',' LengthNum - { - $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} - } -| "LIMIT" LengthNum "OFFSET" LengthNum - { - $$ = &ast.Limit{Offset: $4.(uint64), Count: $2.(uint64)} - } - -SelectStmtDistinct: - /* EMPTY */ - { - $$ = false - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } - -SelectStmtOpts: - SelectStmtDistinct SelectStmtCalcFoundRows - { - // TODO: return calc_found_rows opt and support more other options - $$ = $1 - } - -SelectStmtCalcFoundRows: - %prec lowerThanCalcFoundRows - { - $$ = false - } -| "SQL_CALC_FOUND_ROWS" - { - $$ = true - } - -SelectStmtFieldList: - FieldList - { - $$ = &ast.FieldList{Fields: $1.([]*ast.SelectField)} - } - -SelectStmtGroup: - /* EMPTY */ - { - $$ = nil - } -| GroupByClause - -// See: https://dev.mysql.com/doc/refman/5.7/en/subqueries.html -SubSelect: - '(' SelectStmt ')' - { - s := $2.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1) - src := yylex.(*lexer).src - // See the implemention of yyParse function - s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) - $$ = &ast.SubqueryExpr{Query: s} - } -| '(' UnionStmt ')' - { - s := $2.(*ast.UnionStmt) - src := yylex.(*lexer).src - // See the implemention of yyParse function - s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) - $$ = &ast.SubqueryExpr{Query: s} - } - -// See: https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html -SelectLockOpt: - /* empty */ - { - $$ = ast.SelectLockNone - } -| "FOR" "UPDATE" - { - $$ = ast.SelectLockForUpdate - } -| "LOCK" "IN" "SHARE" "MODE" - { - $$ = ast.SelectLockInShareMode - } - -// See: https://dev.mysql.com/doc/refman/5.7/en/union.html -UnionStmt: - UnionClauseList "UNION" UnionOpt SelectStmt - { - union := $1.(*ast.UnionStmt) - union.Distinct = union.Distinct || $3.(bool) - lastSelect := union.Selects[len(union.Selects)-1] - yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) - union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) - $$ = union - } -| UnionClauseList "UNION" UnionOpt '(' SelectStmt ')' OrderByOptional SelectStmtLimit - { - union := $1.(*ast.UnionStmt) - union.Distinct = union.Distinct || $3.(bool) - lastSelect := union.Selects[len(union.Selects)-1] - yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1) - st := $5.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) - union.Selects = append(union.Selects, st) - if $7 != nil { - union.OrderBy = $7.(*ast.OrderByClause) - } - if $8 != nil { - union.Limit = $8.(*ast.Limit) - } - $$ = union - } - -UnionClauseList: - UnionSelect - { - selects := []*ast.SelectStmt{$1.(*ast.SelectStmt)} - $$ = &ast.UnionStmt{ - Selects: selects, - } - } -| UnionClauseList "UNION" UnionOpt UnionSelect - { - union := $1.(*ast.UnionStmt) - union.Distinct = union.Distinct || $3.(bool) - lastSelect := union.Selects[len(union.Selects)-1] - yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) - union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) - $$ = union - } - -UnionSelect: - SelectStmt -| '(' SelectStmt ')' - { - st := $2.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1) - $$ = st - } - -UnionOpt: - { - $$ = true - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } - - -/********************Set Statement*******************************/ -SetStmt: - "SET" VariableAssignmentList - { - $$ = &ast.SetStmt{Variables: $2.([]*ast.VariableAssignment)} - } -| "SET" "NAMES" StringName - { - $$ = &ast.SetCharsetStmt{Charset: $3.(string)} - } -| "SET" "NAMES" StringName "COLLATE" StringName - { - $$ = &ast.SetCharsetStmt{ - Charset: $3.(string), - Collate: $5.(string), - } - } -| "SET" CharsetKw StringName - { - $$ = &ast.SetCharsetStmt{Charset: $3.(string)} - } -| "SET" "PASSWORD" eq PasswordOpt - { - $$ = &ast.SetPwdStmt{Password: $4.(string)} - } -| "SET" "PASSWORD" "FOR" Username eq PasswordOpt - { - $$ = &ast.SetPwdStmt{User: $4.(string), Password: $6.(string)} - } - -VariableAssignment: - Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $1.(string), Value: $3.(ast.ExprNode), IsSystem: true} - } -| "GLOBAL" Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsGlobal: true, IsSystem: true} - } -| "SESSION" Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} - } -| "LOCAL" Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} - } -| "SYS_VAR" eq Expression - { - v := strings.ToLower($1.(string)) - var isGlobal bool - if strings.HasPrefix(v, "@@global.") { - isGlobal = true - v = strings.TrimPrefix(v, "@@global.") - } else if strings.HasPrefix(v, "@@session.") { - v = strings.TrimPrefix(v, "@@session.") - } else if strings.HasPrefix(v, "@@local.") { - v = strings.TrimPrefix(v, "@@local.") - } else if strings.HasPrefix(v, "@@") { - v = strings.TrimPrefix(v, "@@") - } - $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode), IsGlobal: isGlobal, IsSystem: true} - } -| "USER_VAR" eq Expression - { - v := $1.(string) - v = strings.TrimPrefix(v, "@") - $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode)} - } - -VariableAssignmentList: - { - $$ = []*ast.VariableAssignment{} - } -| VariableAssignment - { - $$ = []*ast.VariableAssignment{$1.(*ast.VariableAssignment)} - } -| VariableAssignmentList ',' VariableAssignment - { - $$ = append($1.([]*ast.VariableAssignment), $3.(*ast.VariableAssignment)) - } - -Variable: - SystemVariable | UserVariable - -SystemVariable: - "SYS_VAR" - { - v := strings.ToLower($1.(string)) - var isGlobal bool - if strings.HasPrefix(v, "@@global.") { - isGlobal = true - v = strings.TrimPrefix(v, "@@global.") - } else if strings.HasPrefix(v, "@@session.") { - v = strings.TrimPrefix(v, "@@session.") - } else if strings.HasPrefix(v, "@@local.") { - v = strings.TrimPrefix(v, "@@local.") - } else if strings.HasPrefix(v, "@@") { - v = strings.TrimPrefix(v, "@@") - } - $$ = &ast.VariableExpr{Name: v, IsGlobal: isGlobal, IsSystem: true} - } - -UserVariable: - "USER_VAR" - { - v := $1.(string) - v = strings.TrimPrefix(v, "@") - $$ = &ast.VariableExpr{Name: v, IsGlobal: false, IsSystem: false} - } - -Username: - stringLit "AT" stringLit - { - $$ = $1.(string) + "@" + $3.(string) - } - -PasswordOpt: - stringLit - { - $$ = $1.(string) - } -| "PASSWORD" '(' AuthString ')' - { - $$ = $3.(string) - } - -AuthString: - stringLit - { - $$ = $1.(string) - } - -/****************************Show Statement*******************************/ -ShowStmt: - "SHOW" "ENGINES" - { - $$ = &ast.ShowStmt{Tp: ast.ShowEngines} - } -| "SHOW" "DATABASES" - { - $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} - } -| "SHOW" "SCHEMAS" - { - $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} - } -| "SHOW" "CHARACTER" "SET" - { - $$ = &ast.ShowStmt{Tp: ast.ShowCharset} - } -| "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt - { - stmt := &ast.ShowStmt{ - Tp: ast.ShowTables, - DBName: $4.(string), - Full: $2.(bool), - } - if $5 != nil { - if x, ok := $5.(*ast.PatternLikeExpr); ok { - stmt.Pattern = x - } else { - stmt.Where = $5.(ast.ExprNode) - } - } - $$ = stmt - } -| "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt - { - $$ = &ast.ShowStmt{ - Tp: ast.ShowColumns, - Table: $4.(*ast.TableName), - DBName: $5.(string), - Full: $2.(bool), - } - } -| "SHOW" "WARNINGS" - { - $$ = &ast.ShowStmt{Tp: ast.ShowWarnings} - } -| "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt - { - stmt := &ast.ShowStmt{ - Tp: ast.ShowVariables, - GlobalScope: $2.(bool), - } - if x, ok := $4.(*ast.PatternLikeExpr); ok { - stmt.Pattern = x - } else if $4 != nil { - stmt.Where = $4.(ast.ExprNode) - } - $$ = stmt - } -| "SHOW" "COLLATION" ShowLikeOrWhereOpt - { - stmt := &ast.ShowStmt{ - Tp: ast.ShowCollation, - } - if x, ok := $3.(*ast.PatternLikeExpr); ok { - stmt.Pattern = x - } else if $3 != nil { - stmt.Where = $3.(ast.ExprNode) - } - $$ = stmt - } -| "SHOW" "CREATE" "TABLE" TableName - { - $$ = &ast.ShowStmt{ - Tp: ast.ShowCreateTable, - Table: $4.(*ast.TableName), - } - } -| "SHOW" "GRANTS" - { - // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html - $$ = &ast.ShowStmt{Tp: ast.ShowGrants} - } -| "SHOW" "GRANTS" "FOR" Username - { - // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html - $$ = &ast.ShowStmt{ - Tp: ast.ShowGrants, - User: $4.(string), - } - } - -ShowLikeOrWhereOpt: - { - $$ = nil - } -| "LIKE" PrimaryExpression - { - $$ = &ast.PatternLikeExpr{Pattern: $2.(ast.ExprNode)} - } -| "WHERE" Expression - { - $$ = $2.(ast.ExprNode) - } - -GlobalScope: - { - $$ = false - } -| "GLOBAL" - { - $$ = true - } -| "SESSION" - { - $$ = false - } - -OptFull: - { - $$ = false - } -| "FULL" - { - $$ = true - } - -ShowDatabaseNameOpt: - { - $$ = "" - } -| "FROM" DBName - { - $$ = $2.(string) - } -| "IN" DBName - { - $$ = $2.(string) - } - -ShowTableAliasOpt: - "FROM" TableName - { - $$ = $2.(*ast.TableName) - } -| "IN" TableName - { - $$ = $2.(*ast.TableName) - } - -Statement: - EmptyStmt -| AlterTableStmt -| BeginTransactionStmt -| CommitStmt -| DeallocateStmt -| DeleteFromStmt -| ExecuteStmt -| ExplainStmt -| CreateDatabaseStmt -| CreateIndexStmt -| CreateTableStmt -| CreateUserStmt -| DoStmt -| DropDatabaseStmt -| DropIndexStmt -| DropTableStmt -| GrantStmt -| InsertIntoStmt -| PreparedStmt -| RollbackStmt -| ReplaceIntoStmt -| SelectStmt -| UnionStmt -| SetStmt -| ShowStmt -| TruncateTableStmt -| UpdateStmt -| UseStmt -| SubSelect - { - // `(select 1)`; is a valid select statement - // TODO: This is used to fix issue #320. There may be a better solution. - $$ = $1.(*ast.SubqueryExpr).Query - } - -ExplainableStmt: - SelectStmt -| DeleteFromStmt -| UpdateStmt -| InsertIntoStmt -| ReplaceIntoStmt - -StatementList: - Statement - { - if $1 != nil { - s := $1.(ast.StmtNode) - s.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = append(yylex.(*lexer).list, s) - } - } -| StatementList ';' Statement - { - if $3 != nil { - s := $3.(ast.StmtNode) - s.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = append(yylex.(*lexer).list, s) - } - } - -Constraint: - ConstraintKeywordOpt ConstraintElem - { - cst := $2.(*ast.Constraint) - if $1 != nil { - cst.Name = $1.(string) - } - $$ = cst - } - -TableElement: - ColumnDef - { - $$ = $1.(*ast.ColumnDef) - } -| Constraint - { - $$ = $1.(*ast.Constraint) - } -| "CHECK" '(' Expression ')' - { - /* Nothing to do now */ - $$ = nil - } - -TableElementList: - TableElement - { - if $1 != nil { - $$ = []interface{}{$1.(interface{})} - } else { - $$ = []interface{}{} - } - } -| TableElementList ',' TableElement - { - if $3 != nil { - $$ = append($1.([]interface{}), $3) - } else { - $$ = $1 - } - } - -TableOption: - "ENGINE" Identifier - { - $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $2.(string)} - } -| "ENGINE" eq Identifier - { - $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $3.(string)} - } -| DefaultKwdOpt CharsetKw EqOpt StringName - { - $$ = &ast.TableOption{Tp: ast.TableOptionCharset, StrValue: $4.(string)} - } -| DefaultKwdOpt "COLLATE" EqOpt StringName - { - $$ = &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: $4.(string)} - } -| "AUTO_INCREMENT" eq LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionAutoIncrement, UintValue: $3.(uint64)} - } -| "COMMENT" EqOpt stringLit - { - $$ = &ast.TableOption{Tp: ast.TableOptionComment, StrValue: $3.(string)} - } -| "AVG_ROW_LENGTH" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionAvgRowLength, UintValue: $3.(uint64)} - } -| "CONNECTION" EqOpt stringLit - { - $$ = &ast.TableOption{Tp: ast.TableOptionConnection, StrValue: $3.(string)} - } -| "CHECKSUM" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionCheckSum, UintValue: $3.(uint64)} - } -| "PASSWORD" EqOpt stringLit - { - $$ = &ast.TableOption{Tp: ast.TableOptionPassword, StrValue: $3.(string)} - } -| "COMPRESSION" EqOpt Identifier - { - $$ = &ast.TableOption{Tp: ast.TableOptionCompression, StrValue: $3.(string)} - } -| "KEY_BLOCK_SIZE" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionKeyBlockSize, UintValue: $3.(uint64)} - } -| "MAX_ROWS" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionMaxRows, UintValue: $3.(uint64)} - } -| "MIN_ROWS" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionMinRows, UintValue: $3.(uint64)} - } - - -TableOptionListOpt: - { - $$ = []*ast.TableOption{} - } -| TableOptionList %prec lowerThanComma - -TableOptionList: - TableOption - { - $$ = []*ast.TableOption{$1.(*ast.TableOption)} - } -| TableOptionList TableOption - { - $$ = append($1.([]*ast.TableOption), $2.(*ast.TableOption)) - } -| TableOptionList ',' TableOption - { - $$ = append($1.([]*ast.TableOption), $3.(*ast.TableOption)) - } - - -TruncateTableStmt: - "TRUNCATE" "TABLE" TableName - { - $$ = &ast.TruncateTableStmt{Table: $3.(*ast.TableName)} - } - -/*************************************Type Begin***************************************/ -Type: - NumericType - { - $$ = $1 - } -| StringType - { - $$ = $1 - } -| DateAndTimeType - { - $$ = $1 - } -| "float32" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "float64" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "int64" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "string" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "uint" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "uint64" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } - -NumericType: - IntegerType OptFieldLen FieldOpts - { - // TODO: check flen 0 - x := types.NewFieldType($1.(byte)) - x.Flen = $2.(int) - for _, o := range $3.([]*field.Opt) { - if o.IsUnsigned { - x.Flag |= mysql.UnsignedFlag - } - if o.IsZerofill { - x.Flag |= mysql.ZerofillFlag - } - } - $$ = x - } -| FixedPointType FloatOpt FieldOpts - { - fopt := $2.(*ast.FloatOpt) - x := types.NewFieldType($1.(byte)) - x.Flen = fopt.Flen - x.Decimal = fopt.Decimal - for _, o := range $3.([]*field.Opt) { - if o.IsUnsigned { - x.Flag |= mysql.UnsignedFlag - } - if o.IsZerofill { - x.Flag |= mysql.ZerofillFlag - } - } - $$ = x - } -| FloatingPointType FloatOpt FieldOpts - { - fopt := $2.(*ast.FloatOpt) - x := types.NewFieldType($1.(byte)) - x.Flen = fopt.Flen - if x.Tp == mysql.TypeFloat { - // Fix issue #312 - if x.Flen > 53 { - yylex.(*lexer).errf("Float len(%d) should not be greater than 53", x.Flen) - return 1 - } - if x.Flen > 24 { - x.Tp = mysql.TypeDouble - } - } - x.Decimal =fopt.Decimal - for _, o := range $3.([]*field.Opt) { - if o.IsUnsigned { - x.Flag |= mysql.UnsignedFlag - } - if o.IsZerofill { - x.Flag |= mysql.ZerofillFlag - } - } - $$ = x - } -| BitValueType OptFieldLen - { - x := types.NewFieldType($1.(byte)) - x.Flen = $2.(int) - if x.Flen == -1 || x.Flen == 0 { - x.Flen = 1 - } else if x.Flen > 64 { - yylex.(*lexer).errf("invalid field length %d for bit type, must in [1, 64]", x.Flen) - } - $$ = x - } - -IntegerType: - "TINYINT" - { - $$ = mysql.TypeTiny - } -| "SMALLINT" - { - $$ = mysql.TypeShort - } -| "MEDIUMINT" - { - $$ = mysql.TypeInt24 - } -| "INT" - { - $$ = mysql.TypeLong - } -| "INTEGER" - { - $$ = mysql.TypeLong - } -| "BIGINT" - { - $$ = mysql.TypeLonglong - } -| "BOOL" - { - $$ = mysql.TypeTiny - } -| "BOOLEAN" - { - $$ = mysql.TypeTiny - } - -OptInteger: - {} | "INTEGER" - -FixedPointType: - "DECIMAL" - { - $$ = mysql.TypeNewDecimal - } -| "NUMERIC" - { - $$ = mysql.TypeNewDecimal - } - -FloatingPointType: - "float" - { - $$ = mysql.TypeFloat - } -| "REAL" - { - $$ = mysql.TypeDouble - } -| "DOUBLE" - { - $$ = mysql.TypeDouble - } -| "DOUBLE" "PRECISION" - { - $$ = mysql.TypeDouble - } - -BitValueType: - "BIT" - { - $$ = mysql.TypeBit - } - -StringType: - NationalOpt "CHAR" FieldLen OptBinary OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $3.(int) - if $4.(bool) { - x.Flag |= mysql.BinaryFlag - } - $$ = x - } -| NationalOpt "CHAR" OptBinary OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeString) - if $3.(bool) { - x.Flag |= mysql.BinaryFlag - } - $$ = x - } -| NationalOpt "VARCHAR" FieldLen OptBinary OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeVarchar) - x.Flen = $3.(int) - if $4.(bool) { - x.Flag |= mysql.BinaryFlag - } - x.Charset = $5.(string) - x.Collate = $6.(string) - $$ = x - } -| "BINARY" OptFieldLen - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "VARBINARY" FieldLen - { - x := types.NewFieldType(mysql.TypeVarchar) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| BlobType - { - $$ = $1.(*types.FieldType) - } -| TextType OptBinary OptCharset OptCollate - { - x := $1.(*types.FieldType) - if $2.(bool) { - x.Flag |= mysql.BinaryFlag - } - x.Charset = $3.(string) - x.Collate = $4.(string) - $$ = x - } -| "ENUM" '(' StringList ')' OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeEnum) - x.Elems = $3.([]string) - x.Charset = $5.(string) - x.Collate = $6.(string) - $$ = x - } -| "SET" '(' StringList ')' OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeSet) - x.Elems = $3.([]string) - x.Charset = $5.(string) - x.Collate = $6.(string) - $$ = x - } - -NationalOpt: - { - - } -| "NATIONAL" - { - - } - -BlobType: - "TINYBLOB" - { - x := types.NewFieldType(mysql.TypeTinyBlob) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "BLOB" OptFieldLen - { - x := types.NewFieldType(mysql.TypeBlob) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "MEDIUMBLOB" - { - x := types.NewFieldType(mysql.TypeMediumBlob) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "LONGBLOB" - { - x := types.NewFieldType(mysql.TypeLongBlob) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } - -TextType: - "TINYTEXT" - { - x := types.NewFieldType(mysql.TypeTinyBlob) - $$ = x - - } -| "TEXT" OptFieldLen - { - x := types.NewFieldType(mysql.TypeBlob) - x.Flen = $2.(int) - $$ = x - } -| "MEDIUMTEXT" - { - x := types.NewFieldType(mysql.TypeMediumBlob) - $$ = x - } -| "LONGTEXT" - { - x := types.NewFieldType(mysql.TypeLongBlob) - $$ = x - } - - -DateAndTimeType: - "DATE" - { - x := types.NewFieldType(mysql.TypeDate) - $$ = x - } -| "DATETIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDatetime) - x.Decimal = $2.(int) - $$ = x - } -| "TIMESTAMP" OptFieldLen - { - x := types.NewFieldType(mysql.TypeTimestamp) - x.Decimal = $2.(int) - $$ = x - } -| "TIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDuration) - x.Decimal = $2.(int) - $$ = x - } -| "YEAR" OptFieldLen - { - x := types.NewFieldType(mysql.TypeYear) - x.Flen = $2.(int) - $$ = x - } - -FieldLen: - '(' LengthNum ')' - { - $$ = int($2.(uint64)) - } - -OptFieldLen: - { - /* -1 means unspecified field length*/ - $$ = types.UnspecifiedLength - } -| FieldLen - { - $$ = $1.(int) - } - -FieldOpt: - "UNSIGNED" - { - $$ = &field.Opt{IsUnsigned: true} - } -| "ZEROFILL" - { - $$ = &field.Opt{IsZerofill: true, IsUnsigned: true} - } - -FieldOpts: - { - $$ = []*field.Opt{} - } -| FieldOpts FieldOpt - { - $$ = append($1.([]*field.Opt), $2.(*field.Opt)) - } - -FloatOpt: - { - $$ = &ast.FloatOpt{Flen: types.UnspecifiedLength, Decimal: types.UnspecifiedLength} - } -| FieldLen - { - $$ = &ast.FloatOpt{Flen: $1.(int), Decimal: types.UnspecifiedLength} - } -| Precision - { - $$ = $1.(*ast.FloatOpt) - } - -Precision: - '(' LengthNum ',' LengthNum ')' - { - $$ = &ast.FloatOpt{Flen: int($2.(uint64)), Decimal: int($4.(uint64))} - } - -OptBinary: - { - $$ = false - } -| "BINARY" - { - $$ = true - } - -OptCharset: - { - $$ = "" - } -| CharsetKw StringName - { - $$ = $2.(string) - } - -CharsetKw: - "CHARACTER" "SET" -| "CHARSET" - -OptCollate: - { - $$ = "" - } -| "COLLATE" StringName - { - $$ = $2.(string) - } - -StringList: - stringLit - { - $$ = []string{$1.(string)} - } -| StringList ',' stringLit - { - $$ = append($1.([]string), $3.(string)) - } - -StringName: - stringLit - { - $$ = $1.(string) - } -| Identifier - { - $$ = $1.(string) - } - -/*********************************************************************************** - * Update Statement - * See: https://dev.mysql.com/doc/refman/5.7/en/update.html - ***********************************************************************************/ -UpdateStmt: - "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause - { - var refs *ast.Join - if x, ok := $4.(*ast.Join); ok { - refs = x - } else { - refs = &ast.Join{Left: $4.(ast.ResultSetNode)} - } - st := &ast.UpdateStmt{ - LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: refs}, - List: $6.([]*ast.Assignment), - } - if $7 != nil { - st.Where = $7.(ast.ExprNode) - } - if $8 != nil { - st.Order = $8.(*ast.OrderByClause) - } - if $9 != nil { - st.Limit = $9.(*ast.Limit) - } - $$ = st - if yylex.(*lexer).root { - break - } - } -| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional - { - st := &ast.UpdateStmt{ - LowPriority: $2.(bool), - TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, - List: $6.([]*ast.Assignment), - } - if $7 != nil { - st.Where = $7.(ast.ExprNode) - } - $$ = st - if yylex.(*lexer).root { - break - } - } - -UseStmt: - "USE" DBName - { - $$ = &ast.UseStmt{DBName: $2.(string)} - if yylex.(*lexer).root { - break - } - } - -WhereClause: - "WHERE" Expression - { - $$ = $2.(ast.ExprNode) - } - -WhereClauseOptional: - { - $$ = nil - } -| WhereClause - { - $$ = $1 - } - -CommaOpt: - { - } -| ',' - { - } - -/************************************************************************************ - * Account Management Statements - * https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html - ************************************************************************************/ -CreateUserStmt: - "CREATE" "USER" IfNotExists UserSpecList - { - // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html - $$ = &ast.CreateUserStmt{ - IfNotExists: $3.(bool), - Specs: $4.([]*ast.UserSpec), - } - } - -UserSpec: - Username AuthOption - { - userSpec := &ast.UserSpec{ - User: $1.(string), - } - if $2 != nil { - userSpec.AuthOpt = $2.(*ast.AuthOption) - } - $$ = userSpec - } - -UserSpecList: - UserSpec - { - $$ = []*ast.UserSpec{$1.(*ast.UserSpec)} - } -| UserSpecList ',' UserSpec - { - $$ = append($1.([]*ast.UserSpec), $3.(*ast.UserSpec)) - } - -AuthOption: - { - $$ = nil - } -| "IDENTIFIED" "BY" AuthString - { - $$ = &ast.AuthOption { - AuthString: $3.(string), - ByAuthString: true, - } - } -| "IDENTIFIED" "BY" "PASSWORD" HashString - { - $$ = &ast.AuthOption{ - HashString: $4.(string), - } - } - -HashString: - stringLit - -/************************************************************************************* - * Grant statement - * See: https://dev.mysql.com/doc/refman/5.7/en/grant.html - *************************************************************************************/ -GrantStmt: - "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecList - { - $$ = &ast.GrantStmt{ - Privs: $2.([]*ast.PrivElem), - ObjectType: $4.(ast.ObjectTypeType), - Level: $5.(*ast.GrantLevel), - Users: $7.([]*ast.UserSpec), - } - } - -PrivElem: - PrivType - { - $$ = &ast.PrivElem{ - Priv: $1.(mysql.PrivilegeType), - } - } -| PrivType '(' ColumnNameList ')' - { - $$ = &ast.PrivElem{ - Priv: $1.(mysql.PrivilegeType), - Cols: $3.([]*ast.ColumnName), - } - } - -PrivElemList: - PrivElem - { - $$ = []*ast.PrivElem{$1.(*ast.PrivElem)} - } -| PrivElemList ',' PrivElem - { - $$ = append($1.([]*ast.PrivElem), $3.(*ast.PrivElem)) - } - -PrivType: - "ALL" - { - $$ = mysql.AllPriv - } -| "ALTER" - { - $$ = mysql.AlterPriv - } -| "CREATE" - { - $$ = mysql.CreatePriv - } -| "CREATE" "USER" - { - $$ = mysql.CreateUserPriv - } -| "DELETE" - { - $$ = mysql.DeletePriv - } -| "DROP" - { - $$ = mysql.DropPriv - } -| "EXECUTE" - { - $$ = mysql.ExecutePriv - } -| "INDEX" - { - $$ = mysql.IndexPriv - } -| "INSERT" - { - $$ = mysql.InsertPriv - } -| "SELECT" - { - $$ = mysql.SelectPriv - } -| "SHOW" "DATABASES" - { - $$ = mysql.ShowDBPriv - } -| "UPDATE" - { - $$ = mysql.UpdatePriv - } -| "GRANT" "OPTION" - { - $$ = mysql.GrantPriv - } - -ObjectType: - { - $$ = ast.ObjectTypeNone - } -| "TABLE" - { - $$ = ast.ObjectTypeTable - } - -PrivLevel: - '*' - { - $$ = &ast.GrantLevel { - Level: ast.GrantLevelDB, - } - } -| '*' '.' '*' - { - $$ = &ast.GrantLevel { - Level: ast.GrantLevelGlobal, - } - } -| Identifier '.' '*' - { - $$ = &ast.GrantLevel { - Level: ast.GrantLevelDB, - DBName: $1.(string), - } - } -| Identifier '.' Identifier - { - $$ = &ast.GrantLevel { - Level: ast.GrantLevelTable, - DBName: $1.(string), - TableName: $3.(string), - } - } -| Identifier - { - $$ = &ast.GrantLevel { - Level: ast.GrantLevelTable, - TableName: $1.(string), - } - } -%% diff --git a/ast/parser/parser_test.go b/ast/parser/parser_test.go deleted file mode 100644 index 3e8d7bb278..0000000000 --- a/ast/parser/parser_test.go +++ /dev/null @@ -1,647 +0,0 @@ -package parser - -import ( - "fmt" - "testing" - - . "github.com/pingcap/check" - "github.com/pingcap/tidb/ast" -) - -func TestT(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testParserSuite{}) - -var _ = Suite(&testParserSuite{}) - -type testParserSuite struct { -} - -func (s *testParserSuite) TestSimple(c *C) { - // Testcase for unreserved keywords - unreservedKws := []string{ - "auto_increment", "after", "begin", "bit", "bool", "boolean", "charset", "columns", "commit", - "date", "datetime", "deallocate", "do", "end", "engine", "engines", "execute", "first", "full", - "local", "names", "offset", "password", "prepare", "quick", "rollback", "session", "signed", - "start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown", - "value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified", - "collation", "comment", "avg_row_length", "checksum", "compression", "connection", "key_block_size", - "max_rows", "min_rows", "national", "row", "quarter", "escape", - } - for _, kw := range unreservedKws { - src := fmt.Sprintf("SELECT %s FROM tbl;", kw) - l := NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - c.Assert(l.errs, HasLen, 0, Commentf("source %s", src)) - } - - // Testcase for prepared statement - src := "SELECT id+?, id+? from t;" - l := NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - c.Assert(len(l.Stmts()), Equals, 1) - - // Testcase for -- Comment and unary -- operator - src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;" - l = NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - c.Assert(len(l.Stmts()), Equals, 2) - - // Testcase for CONVERT(expr,type) - src = "SELECT CONVERT('111', SIGNED);" - l = NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - st := l.Stmts()[0] - ss, ok := st.(*ast.SelectStmt) - c.Assert(ok, IsTrue) - c.Assert(len(ss.Fields.Fields), Equals, 1) - cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr) - c.Assert(ok, IsTrue) - c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction) - - // For query start with comment - srcs := []string{ - "/* some comments */ SELECT CONVERT('111', SIGNED) ;", - "/* some comments */ /*comment*/ SELECT CONVERT('111', SIGNED) ;", - "SELECT /*comment*/ CONVERT('111', SIGNED) ;", - "SELECT CONVERT('111', /*comment*/ SIGNED) ;", - "SELECT CONVERT('111', SIGNED) /*comment*/;", - } - for _, src := range srcs { - l = NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - st = l.Stmts()[0] - ss, ok = st.(*ast.SelectStmt) - c.Assert(ok, IsTrue) - } -} - -type testCase struct { - src string - ok bool -} - -func (s *testParserSuite) RunTest(c *C, table []testCase) { - for _, t := range table { - l := NewLexer(t.src) - ok := yyParse(l) == 0 - c.Assert(ok, Equals, t.ok, Commentf("source %v %v", t.src, l.errs)) - switch ok { - case true: - c.Assert(l.errs, HasLen, 0, Commentf("src: %s", t.src)) - case false: - c.Assert(len(l.errs), Not(Equals), 0, Commentf("src: %s", t.src)) - } - } -} -func (s *testParserSuite) TestDMLStmt(c *C) { - table := []testCase{ - {"", true}, - {";", true}, - {"INSERT INTO foo VALUES (1234)", true}, - {"INSERT INTO foo VALUES (1234, 5678)", true}, - // 15 - {"INSERT INTO foo VALUES (1 || 2)", true}, - {"INSERT INTO foo VALUES (1 | 2)", true}, - {"INSERT INTO foo VALUES (false || true)", true}, - {"INSERT INTO foo VALUES (bar(5678))", false}, - // 20 - {"INSERT INTO foo VALUES ()", true}, - {"SELECT * FROM t", true}, - {"SELECT * FROM t AS u", true}, - // 25 - {"SELECT * FROM t, v", true}, - {"SELECT * FROM t AS u, v", true}, - {"SELECT * FROM t, v AS w", true}, - {"SELECT * FROM t AS u, v AS w", true}, - {"SELECT * FROM foo, bar, foo", true}, - // 30 - {"SELECT DISTINCTS * FROM t", false}, - {"SELECT DISTINCT * FROM t", true}, - {"INSERT INTO foo (a) VALUES (42)", true}, - {"INSERT INTO foo (a,) VALUES (42,)", false}, - // 35 - {"INSERT INTO foo (a,b) VALUES (42,314)", true}, - {"INSERT INTO foo (a,b,) VALUES (42,314)", false}, - {"INSERT INTO foo (a,b,) VALUES (42,314,)", false}, - {"INSERT INTO foo () VALUES ()", true}, - {"INSERT INTO foo VALUE ()", true}, - - {"REPLACE INTO foo VALUES (1 || 2)", true}, - {"REPLACE INTO foo VALUES (1 | 2)", true}, - {"REPLACE INTO foo VALUES (false || true)", true}, - {"REPLACE INTO foo VALUES (bar(5678))", false}, - {"REPLACE INTO foo VALUES ()", true}, - {"REPLACE INTO foo (a,b) VALUES (42,314)", true}, - {"REPLACE INTO foo (a,b,) VALUES (42,314)", false}, - {"REPLACE INTO foo (a,b,) VALUES (42,314,)", false}, - {"REPLACE INTO foo () VALUES ()", true}, - {"REPLACE INTO foo VALUE ()", true}, - // 40 - {`SELECT stuff.id - FROM stuff - WHERE stuff.value >= ALL (SELECT stuff.value - FROM stuff)`, true}, - {"BEGIN", true}, - {"START TRANSACTION", true}, - // 45 - {"COMMIT", true}, - {"ROLLBACK", true}, - {` - BEGIN; - INSERT INTO foo VALUES (42, 3.14); - INSERT INTO foo VALUES (-1, 2.78); - COMMIT;`, true}, - {` // A - BEGIN; - INSERT INTO tmp SELECT * from bar; - SELECT * from tmp; - - // B - ROLLBACK;`, true}, - - // set - // user defined - {"SET @a = 1", true}, - // session system variables - {"SET SESSION autocommit = 1", true}, - {"SET @@session.autocommit = 1", true}, - {"SET LOCAL autocommit = 1", true}, - {"SET @@local.autocommit = 1", true}, - {"SET @@autocommit = 1", true}, - {"SET autocommit = 1", true}, - // global system variables - {"SET GLOBAL autocommit = 1", true}, - {"SET @@global.autocommit = 1", true}, - // SET CHARACTER SET - {"SET CHARACTER SET utf8mb4;", true}, - {"SET CHARACTER SET 'utf8mb4';", true}, - // Set password - {"SET PASSWORD = 'password';", true}, - {"SET PASSWORD FOR 'root'@'localhost' = 'password';", true}, - - // qualified select - {"SELECT a.b.c FROM t", true}, - {"SELECT a.b.*.c FROM t", false}, - {"SELECT a.b.* FROM t", true}, - {"SELECT a FROM t", true}, - {"SELECT a.b.c.d FROM t", false}, - - // Do statement - {"DO 1", true}, - {"DO 1 from t", false}, - - // Select for update - {"SELECT * from t for update", true}, - {"SELECT * from t lock in share mode", true}, - - // For alter table - {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED", true}, - {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED FIRST", true}, - {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED AFTER b", true}, - - // from join - {"SELECT * from t1, t2, t3", true}, - {"select * from t1 join t2 left join t3 on t2.id = t3.id", true}, - {"select * from t1 right join t2 on t1.id = t2.id left join t3 on t3.id = t2.id", true}, - {"select * from t1 right join t2 on t1.id = t2.id left join t3", false}, - - // For show full columns - {"show columns in t;", true}, - {"show full columns in t;", true}, - - // For set names - {"set names utf8", true}, - {"set names utf8 collate utf8_unicode_ci", true}, - - // For show character set - {"show character set;", true}, - // For on duplicate key update - {"INSERT INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, - {"INSERT IGNORE INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, - - // For SHOW statement - {"SHOW VARIABLES LIKE 'character_set_results'", true}, - {"SHOW GLOBAL VARIABLES LIKE 'character_set_results'", true}, - {"SHOW SESSION VARIABLES LIKE 'character_set_results'", true}, - {"SHOW VARIABLES", true}, - {"SHOW GLOBAL VARIABLES", true}, - {"SHOW GLOBAL VARIABLES WHERE Variable_name = 'autocommit'", true}, - {`SHOW FULL TABLES FROM icar_qa LIKE play_evolutions`, true}, - {`SHOW FULL TABLES WHERE Table_Type != 'VIEW'`, true}, - {`SHOW GRANTS`, true}, - {`SHOW GRANTS FOR 'test'@'localhost'`, true}, - - // For default value - {"CREATE TABLE sbtest (id INTEGER UNSIGNED NOT NULL AUTO_INCREMENT, k integer UNSIGNED DEFAULT '0' NOT NULL, c char(120) DEFAULT '' NOT NULL, pad char(60) DEFAULT '' NOT NULL, PRIMARY KEY (id) )", true}, - - // For delete statement - {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, - {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, - {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, - - // For update statement - {"UPDATE t SET id = id + 1 ORDER BY id DESC;", true}, - {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id;", true}, - {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id LIMIT 10;", false}, - {"UPDATE user T0 LEFT OUTER JOIN user_profile T1 ON T1.id = T0.profile_id SET T0.profile_id = 1 WHERE T0.profile_id IN (1);", true}, - - // For select with where clause - {"SELECT * FROM t WHERE 1 = 1", true}, - - // For show collation - {"show collation", true}, - {"show collation like 'utf8%'", true}, - {"show collation where Charset = 'utf8' and Collation = 'utf8_bin'", true}, - - // For dual - {"select 1 from dual", true}, - {"select 1 from dual limit 1", true}, - {"select 1 where exists (select 2)", false}, - {"select 1 from dual where not exists (select 2)", true}, - - // For show create table - {"show create table test.t", true}, - {"show create table t", true}, - - // For https://github.com/pingcap/tidb/issues/320 - {`(select 1);`, true}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestExpression(c *C) { - table := []testCase{ - // Sign expression - {"SELECT ++1", true}, - {"SELECT -*1", false}, - {"SELECT -+1", true}, - {"SELECT -1", true}, - {"SELECT --1", true}, - - // For string literal - {`select '''a''', """a"""`, true}, - {`select ''a''`, false}, - {`select ""a""`, false}, - {`select '''a''';`, true}, - {`select '\'a\'';`, true}, - {`select "\"a\"";`, true}, - {`select """a""";`, true}, - {`select _utf8"string";`, true}, - // For comparison - {"select 1 <=> 0, 1 <=> null, 1 = null", true}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestBuiltin(c *C) { - table := []testCase{ - // For buildin functions - {"SELECT DAYOFMONTH('2007-02-03');", true}, - {"SELECT RAND();", true}, - {"SELECT RAND(1);", true}, - - {"SELECT SUBSTRING('Quadratically',5);", true}, - {"SELECT SUBSTRING('Quadratically',5, 3);", true}, - {"SELECT SUBSTRING('Quadratically' FROM 5);", true}, - {"SELECT SUBSTRING('Quadratically' FROM 5 FOR 3);", true}, - - {"SELECT CONVERT('111', SIGNED);", true}, - - {"SELECT DATABASE();", true}, - {"SELECT USER();", true}, - {"SELECT CURRENT_USER();", true}, - {"SELECT CURRENT_USER;", true}, - - {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', 2);", true}, - {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', -2);", true}, - - {`SELECT LOWER("A"), UPPER("a")`, true}, - - {`SELECT REPLACE('www.mysql.com', 'w', 'Ww')`, true}, - - {`SELECT LOCATE('bar', 'foobarbar');`, true}, - {`SELECT LOCATE('bar', 'foobarbar', 5);`, true}, - - {"select current_date, current_date(), curdate()", true}, - - // For delete statement - {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, - {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, - {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, - - // For time fsp - {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, - - // For row - {"select row(1)", false}, - {"select row(1, 1,)", false}, - {"select (1, 1,)", false}, - {"select row(1, 1) > row(1, 1), row(1, 1, 1) > row(1, 1, 1)", true}, - {"Select (1, 1) > (1, 1)", true}, - {"create table t (row int)", true}, - - // For cast with charset - {"SELECT *, CAST(data AS CHAR CHARACTER SET utf8) FROM t;", true}, - - // For binary operator - {"SELECT binary 'a';", true}, - - // Select time - {"select current_timestamp", true}, - {"select current_timestamp()", true}, - {"select current_timestamp(6)", true}, - {"select now()", true}, - {"select now(6)", true}, - {"select sysdate(), sysdate(6)", true}, - - // For time extract - {`select extract(microsecond from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(second from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(minute from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(hour from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(day from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(week from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(month from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(quarter from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(year from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(second_microsecond from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(minute_microsecond from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(minute_second from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(hour_microsecond from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(hour_second from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(hour_minute from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(day_microsecond from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(day_second from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(day_minute from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(day_hour from "2011-11-11 10:10:10.123456")`, true}, - {`select extract(year_month from "2011-11-11 10:10:10.123456")`, true}, - - // For issue 224 - {`SELECT CAST('test collated returns' AS CHAR CHARACTER SET utf8) COLLATE utf8_bin;`, true}, - - // For trim - {`SELECT TRIM(' bar ');`, true}, - {`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true}, - {`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true}, - {`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true}, - - // For date_add - {`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, - {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, - - // For date_sub - {`select date_sub("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 10 second)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 10 day)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 1 week)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 1 month)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval 1 year)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, - {`select date_sub("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestIdentifier(c *C) { - table := []testCase{ - // For quote identifier - {"select `a`, `a.b`, `a b` from t", true}, - // For unquoted identifier - {"create table MergeContextTest$Simple (value integer not null, primary key (value))", true}, - // For as - {"select 1 as a, 1 as `a`, 1 as \"a\", 1 as 'a'", true}, - {`select 1 as a, 1 as "a", 1 as 'a'`, true}, - {`select 1 a, 1 "a", 1 'a'`, true}, - {`select * from t as "a"`, false}, - {`select * from t a`, true}, - {`select * from t as a`, true}, - {"select 1 full, 1 row, 1 abs", true}, - {"select * from t full, t1 row, t2 abs", true}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestDDL(c *C) { - table := []testCase{ - {"CREATE", false}, - {"CREATE TABLE", false}, - {"CREATE TABLE foo (", false}, - {"CREATE TABLE foo ()", false}, - {"CREATE TABLE foo ();", false}, - {"CREATE TABLE foo (a TINYINT UNSIGNED);", true}, - {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, - {"CREATE TABLE foo (a bigint unsigned, b bool);", true}, - {"CREATE TABLE foo (a TINYINT, b SMALLINT) CREATE TABLE bar (x INT, y int64)", false}, - {"CREATE TABLE foo (a int, b float); CREATE TABLE bar (x double, y float)", true}, - {"CREATE TABLE foo (a bytes)", false}, - {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, - {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) -- foo", true}, - {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) // foo", true}, - {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, - {"CREATE TABLE foo /* foo */ (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, - {"CREATE TABLE foo (name CHAR(50) BINARY)", true}, - {"CREATE TABLE foo (name CHAR(50) COLLATE utf8_bin)", true}, - {"CREATE TABLE foo (name CHAR(50) CHARACTER SET utf8)", true}, - {"CREATE TABLE foo (name CHAR(50) BINARY CHARACTER SET utf8 COLLATE utf8_bin)", true}, - - {"CREATE TABLE foo (a.b, b);", false}, - {"CREATE TABLE foo (a, b.c);", false}, - // For table option - {"create table t (c int) avg_row_length = 3", true}, - {"create table t (c int) avg_row_length 3", true}, - {"create table t (c int) checksum = 0", true}, - {"create table t (c int) checksum 1", true}, - {"create table t (c int) compression = none", true}, - {"create table t (c int) compression lz4", true}, - {"create table t (c int) connection = 'abc'", true}, - {"create table t (c int) connection 'abc'", true}, - {"create table t (c int) key_block_size = 1024", true}, - {"create table t (c int) key_block_size 1024", true}, - {"create table t (c int) max_rows = 1000", true}, - {"create table t (c int) max_rows 1000", true}, - {"create table t (c int) min_rows = 1000", true}, - {"create table t (c int) min_rows 1000", true}, - {"create table t (c int) password = 'abc'", true}, - {"create table t (c int) password 'abc'", true}, - // For check clause - {"create table t (c1 bool, c2 bool, check (c1 in (0, 1)), check (c2 in (0, 1)))", true}, - {"CREATE TABLE Customer (SD integer CHECK (SD > 0), First_Name varchar(30));", true}, - - {"create database xxx", true}, - {"create database if exists xxx", false}, - {"create database if not exists xxx", true}, - {"create schema xxx", true}, - {"create schema if exists xxx", false}, - {"create schema if not exists xxx", true}, - // For drop datbase/schema/table - {"drop database xxx", true}, - {"drop database if exists xxx", true}, - {"drop database if not exists xxx", false}, - {"drop schema xxx", true}, - {"drop schema if exists xxx", true}, - {"drop schema if not exists xxx", false}, - {"drop table xxx", true}, - {"drop table xxx, yyy", true}, - {"drop tables xxx", true}, - {"drop tables xxx, yyy", true}, - {"drop table if exists xxx", true}, - {"drop table if not exists xxx", false}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestType(c *C) { - table := []testCase{ - // For time fsp - {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, - - // For hexadecimal - {"SELECT x'0a', X'11', 0x11", true}, - {"select x'0xaa'", false}, - {"select 0X11", false}, - - // For bit - {"select 0b01, 0b0, b'11', B'11'", true}, - {"select 0B01", false}, - {"select 0b21", false}, - - // For enum and set type - {"create table t (c1 enum('a', 'b'), c2 set('a', 'b'))", true}, - {"create table t (c1 enum)", false}, - {"create table t (c1 set)", false}, - - // For blob and text field length - {"create table t (c1 blob(1024), c2 text(1024))", true}, - - // For year - {"create table t (y year(4), y1 year)", true}, - - // For national - {"create table t (c1 national char(2), c2 national varchar(2))", true}, - - // For https://github.com/pingcap/tidb/issues/312 - {`create table t (c float(53));`, true}, - {`create table t (c float(54));`, false}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestPrivilege(c *C) { - table := []testCase{ - // For create user - {`CREATE USER IF NOT EXISTS 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, - {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, - {`CREATE USER 'root'@'localhost' IDENTIFIED BY PASSWORD 'hashstring'`, true}, - {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password', 'root'@'127.0.0.1' IDENTIFIED BY PASSWORD 'hashstring'`, true}, - - // For grant statement - {"GRANT ALL ON db1.* TO 'jeffrey'@'localhost';", true}, - {"GRANT SELECT ON db2.invoice TO 'jeffrey'@'localhost';", true}, - {"GRANT ALL ON *.* TO 'someuser'@'somehost';", true}, - {"GRANT SELECT, INSERT ON *.* TO 'someuser'@'somehost';", true}, - {"GRANT ALL ON mydb.* TO 'someuser'@'somehost';", true}, - {"GRANT SELECT, INSERT ON mydb.* TO 'someuser'@'somehost';", true}, - {"GRANT ALL ON mydb.mytbl TO 'someuser'@'somehost';", true}, - {"GRANT SELECT, INSERT ON mydb.mytbl TO 'someuser'@'somehost';", true}, - {"GRANT SELECT (col1), INSERT (col1,col2) ON mydb.mytbl TO 'someuser'@'somehost';", true}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestComment(c *C) { - table := []testCase{ - {"create table t (c int comment 'comment')", true}, - {"create table t (c int) comment = 'comment'", true}, - {"create table t (c int) comment 'comment'", true}, - {"create table t (c int) comment comment", false}, - {"create table t (comment text)", true}, - // For comment in query - {"/*comment*/ /*comment*/ select c /* this is a comment */ from t;", true}, - } - s.RunTest(c, table) -} -func (s *testParserSuite) TestSubquery(c *C) { - table := []testCase{ - // For compare subquery - {"SELECT 1 > (select 1)", true}, - {"SELECT 1 > ANY (select 1)", true}, - {"SELECT 1 > ALL (select 1)", true}, - {"SELECT 1 > SOME (select 1)", true}, - - // For exists subquery - {"SELECT EXISTS select 1", false}, - {"SELECT EXISTS (select 1)", true}, - {"SELECT + EXISTS (select 1)", true}, - {"SELECT - EXISTS (select 1)", true}, - {"SELECT NOT EXISTS (select 1)", true}, - {"SELECT + NOT EXISTS (select 1)", false}, - {"SELECT - NOT EXISTS (select 1)", false}, - } - s.RunTest(c, table) -} -func (s *testParserSuite) TestUnion(c *C) { - table := []testCase{ - {"select c1 from t1 union select c2 from t2", true}, - {"select c1 from t1 union (select c2 from t2)", true}, - {"select c1 from t1 union (select c2 from t2) order by c1", true}, - {"select c1 from t1 union select c2 from t2 order by c2", true}, - {"select c1 from t1 union (select c2 from t2) limit 1", true}, - {"select c1 from t1 union (select c2 from t2) limit 1, 1", true}, - {"select c1 from t1 union (select c2 from t2) order by c1 limit 1", true}, - {"(select c1 from t1) union distinct select c2 from t2", true}, - {"(select c1 from t1) union all select c2 from t2", true}, - {"(select c1 from t1) union (select c2 from t2) order by c1 union select c3 from t3", false}, - {"(select c1 from t1) union (select c2 from t2) limit 1 union select c3 from t3", false}, - {"(select c1 from t1) union select c2 from t2 union (select c3 from t3) order by c1 limit 1", true}, - {"select (select 1 union select 1) as a", true}, - {"select * from (select 1 union select 2) as a", true}, - {"insert into t select c1 from t1 union select c2 from t2", true}, - {"insert into t (c) select c1 from t1 union select c2 from t2", true}, - } - s.RunTest(c, table) -} - -func (s *testParserSuite) TestLikeEscape(c *C) { - table := []testCase{ - // For like escape - {`select "abc_" like "abc\\_" escape ''`, true}, - {`select "abc_" like "abc\\_" escape '\\'`, true}, - {`select "abc_" like "abc\\_" escape '||'`, false}, - {`select "abc" like "escape" escape '+'`, true}, - } - - s.RunTest(c, table) -} diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l deleted file mode 100644 index 9ee7b1b8a9..0000000000 --- a/ast/parser/scanner.l +++ /dev/null @@ -1,1138 +0,0 @@ -%{ -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package parser - -import ( - "fmt" - "math" - "strconv" - "strings" - "unicode" - - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/util/charset" - "github.com/pingcap/tidb/util/stringutil" -) - -type lexer struct { - c int - col int - errs []error - expr ast.ExprNode - i int - inj int - lcol int - line int - list []ast.StmtNode - ncol int - nline int - sc int - src string - val []byte - ungetBuf []byte - root bool - prepare bool - stmtStartPos int - stringLit []byte - - // record token's offset of the input - tokenEndOffset int - tokenStartOffset int - - // Charset information - charset string - collation string -} - - -// NewLexer builds a new lexer. -func NewLexer(src string) (l *lexer) { - l = &lexer{ - src: src, - nline: 1, - ncol: 0, - } - l.next() - return -} - -func (l *lexer) Errors() []error { - return l.errs -} - -func (l *lexer) Stmts() []ast.StmtNode { - return l.list -} - -func (l *lexer) Expr() ast.ExprNode { - return l.expr -} - -func (l *lexer) Inj() int { - return l.inj -} - -func (l *lexer) SetInj(inj int) { - l.inj = inj -} - -func (l *lexer) SetPrepare() { - l.prepare = true -} - -func (l *lexer) IsPrepare() bool { - return l.prepare -} - -func (l *lexer) Root() bool { - return l.root -} - -func (l *lexer) SetRoot(root bool) { - l.root = root -} - -func (l *lexer) SetCharsetInfo(charset, collation string) { - l.charset = charset - l.collation = collation -} - -func (l *lexer) GetCharsetInfo() (string, string) { - return l.charset, l.collation -} - -// The select statement is not at the end of the whole statement, if the last -// field text was set from its offset to the end of the src string, update -// the last field text. -func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { - lastField := st.Fields.Fields[len(st.Fields.Fields)-1] - if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 { - lastField.SetText(l.src[lastField.Offset:lastEnd]) - } -} - -func (l *lexer) unget(b byte) { - l.ungetBuf = append(l.ungetBuf, b) - l.i-- - l.ncol-- - l.tokenEndOffset-- -} - -func (l *lexer) next() int { - if un := len(l.ungetBuf); un > 0 { - nc := l.ungetBuf[0] - l.ungetBuf = l.ungetBuf[1:] - l.c = int(nc) - return l.c - } - - if l.c != 0 { - l.val = append(l.val, byte(l.c)) - } - l.c = 0 - if l.i < len(l.src) { - l.c = int(l.src[l.i]) - l.i++ - } - switch l.c { - case '\n': - l.lcol = l.ncol - l.nline++ - l.ncol = 0 - default: - l.ncol++ - } - l.tokenEndOffset++ - return l.c -} - -func (l *lexer) err0(ln, c int, arg interface{}) { - var argStr string - if arg != nil { - argStr = fmt.Sprintf(" %v", arg) - } - - err := fmt.Errorf("line %d column %d near \"%s\"%s", ln, c, l.val, argStr) - l.errs = append(l.errs, err) -} - -func (l *lexer) err(arg interface{}) { - l.err0(l.line, l.col, arg) -} - -func (l *lexer) errf(format string, args ...interface{}) { - s := fmt.Sprintf(format, args...) - l.err0(l.line, l.col, s) -} - -func (l *lexer) Error(s string) { - // Notice: ignore origin error info. - l.err(nil) -} - -func (l *lexer) stmtText() string { - endPos := l.i - if l.src[l.i-1] == '\n' { - endPos = l.i-1 // trim new line - } - if l.src[l.stmtStartPos] == '\n' { - l.stmtStartPos++ - } - - text := l.src[l.stmtStartPos:endPos] - - l.stmtStartPos = l.i - return text -} - - -func (l *lexer) Lex(lval *yySymType) (r int) { - defer func() { - lval.line, lval.col, lval.offset = l.line, l.col, l.tokenStartOffset - l.tokenStartOffset = l.tokenEndOffset - }() - const ( - INITIAL = iota - S1 - S2 - S3 - S4 - ) - - if n := l.inj; n != 0 { - l.inj = 0 - return n - } - - c0, c := 0, l.c -%} - -int_lit {decimal_lit}|{octal_lit} -decimal_lit [1-9][0-9]* -octal_lit 0[0-7]* -hex_lit 0[xX][0-9a-fA-F]+|[xX]"'"[0-9a-fA-F]+"'" -bit_lit 0[bB][01]+|[bB]"'"[01]+"'" - -float_lit {D}"."{D}?{E}?|{D}{E}|"."{D}{E}? -D [0-9]+ -E [eE][-+]?[0-9]+ - -imaginary_ilit {D}i -imaginary_lit {float_lit}i - -a [aA] -b [bB] -c [cC] -d [dD] -e [eE] -f [fF] -g [gG] -h [hH] -i [iI] -j [jJ] -k [kK] -l [lL] -m [mM] -n [nN] -o [oO] -p [pP] -q [qQ] -r [rR] -s [sS] -t [tT] -u [uU] -v [vV] -w [wW] -x [xX] -y [yY] -z [zZ] - -abs {a}{b}{s} -add {a}{d}{d} -after {a}{f}{t}{e}{r} -all {a}{l}{l} -alter {a}{l}{t}{e}{r} -and {a}{n}{d} -any {a}{n}{y} -as {a}{s} -asc {a}{s}{c} -auto_increment {a}{u}{t}{o}_{i}{n}{c}{r}{e}{m}{e}{n}{t} -avg {a}{v}{g} -avg_row_length {a}{v}{g}_{r}{o}{w}_{l}{e}{n}{g}{t}{h} -begin {b}{e}{g}{i}{n} -between {b}{e}{t}{w}{e}{e}{n} -both {b}{o}{t}{h} -by {b}{y} -case {c}{a}{s}{e} -cast {c}{a}{s}{t} -character {c}{h}{a}{r}{a}{c}{t}{e}{r} -charset {c}{h}{a}{r}{s}{e}{t} -check {c}{h}{e}{c}{k} -checksum {c}{h}{e}{c}{k}{s}{u}{m} -coalesce {c}{o}{a}{l}{e}{s}{c}{e} -collate {c}{o}{l}{l}{a}{t}{e} -collation {c}{o}{l}{l}{a}{t}{i}{o}{n} -column {c}{o}{l}{u}{m}{n} -columns {c}{o}{l}{u}{m}{n}{s} -comment {c}{o}{m}{m}{e}{n}{t} -commit {c}{o}{m}{m}{i}{t} -compression {c}{o}{m}{p}{r}{e}{s}{s}{i}{o}{n} -concat {c}{o}{n}{c}{a}{t} -concat_ws {c}{o}{n}{c}{a}{t}_{w}{s} -connection {c}{o}{n}{n}{e}{c}{t}{i}{o}{n} -constraint {c}{o}{n}{s}{t}{r}{a}{i}{n}{t} -convert {c}{o}{n}{v}{e}{r}{t} -count {c}{o}{u}{n}{t} -create {c}{r}{e}{a}{t}{e} -cross {c}{r}{o}{s}{s} -curdate {c}{u}{r}{d}{a}{t}{e} -current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e} -current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r} -database {d}{a}{t}{a}{b}{a}{s}{e} -databases {d}{a}{t}{a}{b}{a}{s}{e}{s} -date_add {d}{a}{t}{e}_{a}{d}{d} -date_sub {d}{a}{t}{e}_{s}{u}{b} -day {d}{a}{y} -dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k} -dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h} -dayofyear {d}{a}{y}{o}{f}{y}{e}{a}{r} -deallocate {d}{e}{a}{l}{l}{o}{c}{a}{t}{e} -default {d}{e}{f}{a}{u}{l}{t} -delayed {d}{e}{l}{a}{y}{e}{d} -delete {d}{e}{l}{e}{t}{e} -drop {d}{r}{o}{p} -desc {d}{e}{s}{c} -describe {d}{e}{s}{c}{r}{i}{b}{e} -distinct {d}{i}{s}{t}{i}{n}{c}{t} -div {d}{i}{v} -do {d}{o} -dual {d}{u}{a}{l} -duplicate {d}{u}{p}{l}{i}{c}{a}{t}{e} -else {e}{l}{s}{e} -end {e}{n}{d} -engine {e}{n}{g}{i}{n}{e} -engines {e}{n}{g}{i}{n}{e}{s} -escape {e}{s}{c}{a}{p}{e} -execute {e}{x}{e}{c}{u}{t}{e} -exists {e}{x}{i}{s}{t}{s} -explain {e}{x}{p}{l}{a}{i}{n} -extract {e}{x}{t}{r}{a}{c}{t} -first {f}{i}{r}{s}{t} -for {f}{o}{r} -foreign {f}{o}{r}{e}{i}{g}{n} -found_rows {f}{o}{u}{n}{d}_{r}{o}{w}{s} -from {f}{r}{o}{m} -full {f}{u}{l}{l} -fulltext {f}{u}{l}{l}{t}{e}{x}{t} -global {g}{l}{o}{b}{a}{l} -grant {g}{r}{a}{n}{t} -grants {g}{r}{a}{n}{t}{s} -group {g}{r}{o}{u}{p} -group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} -having {h}{a}{v}{i}{n}{g} -high_priority {h}{i}{g}{h}_{p}{r}{i}{o}{r}{i}{t}{y} -hour {h}{o}{u}{r} -identified {i}{d}{e}{n}{t}{i}{f}{i}{e}{d} -if {i}{f} -ifnull {i}{f}{n}{u}{l}{l} -ignore {i}{g}{n}{o}{r}{e} -in {i}{n} -index {i}{n}{d}{e}{x} -inner {i}{n}{n}{e}{r} -insert {i}{n}{s}{e}{r}{t} -interval {i}{n}{t}{e}{r}{v}{a}{l} -into {i}{n}{t}{o} -is {i}{s} -join {j}{o}{i}{n} -key {k}{e}{y} -key_block_size {k}{e}{y}_{b}{l}{o}{c}{k}_{s}{i}{z}{e} -leading {l}{e}{a}{d}{i}{n}{g} -left {l}{e}{f}{t} -length {l}{e}{n}{g}{t}{h} -like {l}{i}{k}{e} -limit {l}{i}{m}{i}{t} -local {l}{o}{c}{a}{l} -locate {l}{o}{c}{a}{t}{e} -lock {l}{o}{c}{k} -lower {l}{o}{w}{e}{r} -low_priority {l}{o}{w}_{p}{r}{i}{o}{r}{i}{t}{y} -max_rows {m}{a}{x}_{r}{o}{w}{s} -microsecond {m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -minute {m}{i}{n}{u}{t}{e} -min_rows {m}{i}{n}_{r}{o}{w}{s} -mod {m}{o}{d} -mode {m}{o}{d}{e} -month {m}{o}{n}{t}{h} -names {n}{a}{m}{e}{s} -national {n}{a}{t}{i}{o}{n}{a}{l} -not {n}{o}{t} -offset {o}{f}{f}{s}{e}{t} -on {o}{n} -option {o}{p}{t}{i}{o}{n} -or {o}{r} -order {o}{r}{d}{e}{r} -outer {o}{u}{t}{e}{r} -password {p}{a}{s}{s}{w}{o}{r}{d} -prepare {p}{r}{e}{p}{a}{r}{e} -primary {p}{r}{i}{m}{a}{r}{y} -quarter {q}{u}{a}{r}{t}{e}{r} -quick {q}{u}{i}{c}{k} -rand {r}{a}{n}{d} -repeat {r}{e}{p}{e}{a}{t} -references {r}{e}{f}{e}{r}{e}{n}{c}{e}{s} -regexp {r}{e}{g}{e}{x}{p} -replace {r}{e}{p}{l}{a}{c}{e} -right {r}{i}{g}{h}{t} -rlike {r}{l}{i}{k}{e} -rollback {r}{o}{l}{l}{b}{a}{c}{k} -row {r}{o}{w} -schema {s}{c}{h}{e}{m}{a} -schemas {s}{c}{h}{e}{m}{a}{s} -second {s}{e}{c}{o}{n}{d} -select {s}{e}{l}{e}{c}{t} -session {s}{e}{s}{s}{i}{o}{n} -set {s}{e}{t} -share {s}{h}{a}{r}{e} -show {s}{h}{o}{w} -some {s}{o}{m}{e} -start {s}{t}{a}{r}{t} -substring {s}{u}{b}{s}{t}{r}{i}{n}{g} -substring_index {s}{u}{b}{s}{t}{r}{i}{n}{g}_{i}{n}{d}{e}{x} -sum {s}{u}{m} -sysdate {s}{y}{s}{d}{a}{t}{e} -table {t}{a}{b}{l}{e} -tables {t}{a}{b}{l}{e}{s} -then {t}{h}{e}{n} -to {t}{o} -trailing {t}{r}{a}{i}{l}{i}{n}{g} -transaction {t}{r}{a}{n}{s}{a}{c}{t}{i}{o}{n} -trim {t}{r}{i}{m} -truncate {t}{r}{u}{n}{c}{a}{t}{e} -max {m}{a}{x} -min {m}{i}{n} -unknown {u}{n}{k}{n}{o}{w}{n} -union {u}{n}{i}{o}{n} -unique {u}{n}{i}{q}{u}{e} -nullif {n}{u}{l}{l}{i}{f} -update {u}{p}{d}{a}{t}{e} -upper {u}{p}{p}{e}{r} -value {v}{a}{l}{u}{e} -values {v}{a}{l}{u}{e}{s} -variables {v}{a}{r}{i}{a}{b}{l}{e}{s} -warnings {w}{a}{r}{n}{i}{n}{g}{s} -week {w}{e}{e}{k} -weekday {w}{e}{e}{k}{d}{a}{y} -weekofyear {w}{e}{e}{k}{o}{f}{y}{e}{a}{r} -where {w}{h}{e}{r}{e} -when {w}{h}{e}{n} -xor {x}{o}{r} -yearweek {y}{e}{a}{r}{w}{e}{e}{k} - -null {n}{u}{l}{l} -false {f}{a}{l}{s}{e} -true {t}{r}{u}{e} - -calc_found_rows {s}{q}{l}_{c}{a}{l}{c}_{f}{o}{u}{n}{d}_{r}{o}{w}{s} - -current_ts {c}{u}{r}{r}{e}{n}{t}_{t}{i}{m}{e}{s}{t}{a}{m}{p} -localtime {l}{o}{c}{a}{l}{t}{i}{m}{e} -localts {l}{o}{c}{a}{l}{t}{i}{m}{e}{s}{t}{a}{m}{p} -now {n}{o}{w} - -bit {b}{i}{t} -tiny {t}{i}{n}{y} -tinyint {t}{i}{n}{y}{i}{n}{t} -smallint {s}{m}{a}{l}{l}{i}{n}{t} -mediumint {m}{e}{d}{i}{u}{m}{i}{n}{t} -int {i}{n}{t} -integer {i}{n}{t}{e}{g}{e}{r} -bigint {b}{i}{g}{i}{n}{t} -real {r}{e}{a}{l} -double {d}{o}{u}{b}{l}{e} -float {f}{l}{o}{a}{t} -decimal {d}{e}{c}{i}{m}{a}{l} -numeric {n}{u}{m}{e}{r}{i}{c} -date {d}{a}{t}{e} -time {t}{i}{m}{e} -timestamp {t}{i}{m}{e}{s}{t}{a}{m}{p} -datetime {d}{a}{t}{e}{t}{i}{m}{e} -year {y}{e}{a}{r} -char {c}{h}{a}{r} -varchar {v}{a}{r}{c}{h}{a}{r} -binary {b}{i}{n}{a}{r}{y} -varbinary {v}{a}{r}{b}{i}{n}{a}{r}{y} -tinyblob {t}{i}{n}{y}{b}{l}{o}{b} -blob {b}{l}{o}{b} -mediumblob {m}{e}{d}{i}{u}{m}{b}{l}{o}{b} -longblob {l}{o}{n}{g}{b}{l}{o}{b} -tinytext {t}{i}{n}{y}{t}{e}{x}{t} -text {t}{e}{x}{t} -mediumtext {m}{e}{d}{i}{u}{m}{t}{e}{x}{t} -longtext {l}{o}{n}{g}{t}{e}{x}{t} -enum {e}{n}{u}{m} -precision {p}{r}{e}{c}{i}{s}{i}{o}{n} - -signed {s}{i}{g}{n}{e}{d} -unsigned {u}{n}{s}{i}{g}{n}{e}{d} -zerofill {z}{e}{r}{o}{f}{i}{l}{l} - -bigrat {b}{i}{g}{r}{a}{t} -bool {b}{o}{o}{l} -boolean {b}{o}{o}{l}{e}{a}{n} -byte {b}{y}{t}{e} -duration {d}{u}{r}{a}{t}{i}{o}{n} -rune {r}{u}{n}{e} -string {s}{t}{r}{i}{n}{g} -use {u}{s}{e} -user {u}{s}{e}{r} -using {u}{s}{i}{n}{g} - -idchar0 [a-zA-Z_] -idchars {idchar0}|[0-9$] // See: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html -ident {idchar0}{idchars}* - -user_var "@"{ident} -sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident} - -second_microsecond {s}{e}{c}{o}{n}{d}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -minute_microsecond {m}{i}{n}{u}{t}{e}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -minute_second {m}{i}{n}{u}{t}{e}_{s}{e}{c}{o}{n}{d} -hour_microsecond {h}{o}{u}{r}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -hour_second {h}{o}{u}{r}_{s}{e}{c}{o}{n}{d} -hour_minute {h}{o}{u}{r}_{m}{i}{n}{u}{t}{e} -day_microsecond {d}{a}{y}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -day_second {d}{a}{y}_{s}{e}{c}{o}{n}{d} -day_minute {d}{a}{y}_{m}{i}{n}{u}{t}{e} -day_hour {d}{a}{y}_{h}{o}{u}{r} -year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} - -%yyc c -%yyn c = l.next() -%yyt l.sc - -%x S1 S2 S3 S4 - -%% - l.val = l.val[:0] - c0, l.line, l.col = l.c, l.nline, l.ncol - -<*>\0 return 0 - -[ \t\n\r]+ -#.* -\/\/.* -\/\*([^*]|\*+[^*/])*\*+\/ --- l.sc = S3 -[ \t]+.* {l.sc = 0} -[^ \t] { - l.sc = 0 - l.c = '-' - n := len(l.val) - l.unget(l.val[n-1]) - return '-' - } - -{int_lit} return l.int(lval) -{float_lit} return l.float(lval) -{hex_lit} return l.hex(lval) -{bit_lit} return l.bit(lval) - -\" l.sc = S1 -' l.sc = S2 -` l.sc = S4 - -[^\"\\]* l.stringLit = append(l.stringLit, l.val...) -\\. l.stringLit = append(l.stringLit, l.val...) -\"\" l.stringLit = append(l.stringLit, '"') -\" l.stringLit = append(l.stringLit, '"') - l.sc = 0 - return l.str(lval, "\"") -[^'\\]* l.stringLit = append(l.stringLit, l.val...) -\\. l.stringLit = append(l.stringLit, l.val...) -'' l.stringLit = append(l.stringLit, '\'') -' l.stringLit = append(l.stringLit, '\'') - l.sc = 0 - return l.str(lval, "'") -[^`]* l.stringLit = append(l.stringLit, l.val...) -`` l.stringLit = append(l.stringLit, '`') -` l.sc = 0 - lval.item = string(l.stringLit) - l.stringLit = l.stringLit[0:0] - return identifier - -"&&" return andand -"&^" return andnot -"<<" return lsh -"<=" return le -"=" return eq -">=" return ge -"!=" return neq -"<>" return neq -"||" return oror -">>" return rsh -"<=>" return nulleq - -"@" return at -"?" return placeholder - -{abs} lval.item = string(l.val) - return abs -{add} return add -{after} lval.item = string(l.val) - return after -{all} return all -{alter} return alter -{and} return and -{any} lval.item = string(l.val) - return any -{asc} return asc -{as} return as -{auto_increment} lval.item = string(l.val) - return autoIncrement -{avg} lval.item = string(l.val) - return avg -{avg_row_length} lval.item = string(l.val) - return avgRowLength -{begin} lval.item = string(l.val) - return begin -{between} return between -{both} return both -{by} return by -{case} return caseKwd -{cast} return cast -{character} return character -{charset} lval.item = string(l.val) - return charsetKwd -{check} return check -{checksum} lval.item = string(l.val) - return checksum -{coalesce} lval.item = string(l.val) - return coalesce -{collate} return collate -{collation} lval.item = string(l.val) - return collation -{column} return column -{columns} lval.item = string(l.val) - return columns -{comment} lval.item = string(l.val) - return comment -{commit} lval.item = string(l.val) - return commit -{compression} lval.item = string(l.val) - return compression -{concat} lval.item = string(l.val) - return concat -{concat_ws} lval.item = string(l.val) - return concatWs -{connection} lval.item = string(l.val) - return connection -{constraint} return constraint -{convert} return convert -{count} lval.item = string(l.val) - return count -{create} return create -{cross} return cross -{curdate} lval.item = string(l.val) - return curDate -{current_date} lval.item = string(l.val) - return currentDate -{current_user} lval.item = string(l.val) - return currentUser -{database} lval.item = string(l.val) - return database -{databases} return databases -{date_add} lval.item = string(l.val) - return dateAdd -{date_sub} lval.item = string(l.val) - return dateSub -{day} lval.item = string(l.val) - return day -{dayofweek} lval.item = string(l.val) - return dayofweek -{dayofmonth} lval.item = string(l.val) - return dayofmonth -{dayofyear} lval.item = string(l.val) - return dayofyear -{day_hour} lval.item = string(l.val) - return dayHour -{day_microsecond} lval.item = string(l.val) - return dayMicrosecond -{day_minute} lval.item = string(l.val) - return dayMinute -{day_second} lval.item = string(l.val) - return daySecond -{deallocate} lval.item = string(l.val) - return deallocate -{default} return defaultKwd -{delayed} return delayed -{delete} return deleteKwd -{desc} return desc -{describe} return describe -{drop} return drop -{distinct} return distinct -{div} return div -{do} lval.item = string(l.val) - return do -{dual} return dual -{duplicate} lval.item = string(l.val) - return duplicate -{else} return elseKwd -{end} lval.item = string(l.val) - return end -{engine} lval.item = string(l.val) - return engine -{engines} lval.item = string(l.val) - return engines -{execute} lval.item = string(l.val) - return execute -{enum} return enum -{escape} lval.item = string(l.val) - return escape -{exists} return exists -{explain} return explain -{extract} lval.item = string(l.val) - return extract -{first} lval.item = string(l.val) - return first -{for} return forKwd -{foreign} return foreign -{found_rows} lval.item = string(l.val) - return foundRows -{from} return from -{full} lval.item = string(l.val) - return full -{fulltext} return fulltext -{grant} return grant -{grants} lval.item = string(l.val) - return grants -{group} return group -{group_concat} lval.item = string(l.val) - return groupConcat -{having} return having -{high_priority} return highPriority -{hour} lval.item = string(l.val) - return hour -{hour_microsecond} lval.item = string(l.val) - return hourMicrosecond -{hour_minute} lval.item = string(l.val) - return hourMinute -{hour_second} lval.item = string(l.val) - return hourSecond -{identified} lval.item = string(l.val) - return identified -{if} lval.item = string(l.val) - return ifKwd -{ifnull} lval.item = string(l.val) - return ifNull -{ignore} return ignore -{index} return index -{inner} return inner -{insert} return insert -{interval} return interval -{into} return into -{in} return in -{is} return is -{join} return join -{key} return key -{key_block_size} lval.item = string(l.val) - return keyBlockSize -{leading} return leading -{left} lval.item = string(l.val) - return left -{length} lval.item = string(l.val) - return length -{like} return like -{limit} return limit -{local} lval.item = string(l.val) - return local -{locate} lval.item = string(l.val) - return locate -{lock} return lock -{lower} lval.item = string(l.val) - return lower -{low_priority} return lowPriority -{max} lval.item = string(l.val) - return max -{max_rows} lval.item = string(l.val) - return maxRows -{microsecond} lval.item = string(l.val) - return microsecond -{min} lval.item = string(l.val) - return min -{minute} lval.item = string(l.val) - return minute -{minute_microsecond} lval.item = string(l.val) - return minuteMicrosecond -{minute_second} lval.item = string(l.val) - return minuteSecond -{min_rows} lval.item = string(l.val) - return minRows -{mod} return mod -{mode} lval.item = string(l.val) - return mode -{month} lval.item = string(l.val) - return month -{names} lval.item = string(l.val) - return names -{national} lval.item = string(l.val) - return national -{not} return not -{offset} lval.item = string(l.val) - return offset -{on} return on -{option} return option -{order} return order -{or} return or -{outer} return outer -{password} lval.item = string(l.val) - return password -{prepare} lval.item = string(l.val) - return prepare -{primary} return primary -{quarter} lval.item = string(l.val) - return quarter -{quick} lval.item = string(l.val) - return quick -{right} return right -{rollback} lval.item = string(l.val) - return rollback -{row} lval.item = string(l.val) - return row -{schema} lval.item = string(l.val) - return schema -{schemas} return schemas -{session} lval.item = string(l.val) - return session -{some} lval.item = string(l.val) - return some -{start} lval.item = string(l.val) - return start -{global} lval.item = string(l.val) - return global -{rand} lval.item = string(l.val) - return rand -{repeat} lval.item = string(l.val) - return repeat -{regexp} return regexp -{replace} lval.item = string(l.val) - return replace -{references} return references -{rlike} return rlike - -{sys_var} lval.item = string(l.val) - return sysVar - -{user_var} lval.item = string(l.val) - return userVar -{second} lval.item = string(l.val) - return second -{second_microsecond} lval.item= string(l.val) - return secondMicrosecond -{select} return selectKwd - -{set} return set -{share} return share -{show} return show -{substring} lval.item = string(l.val) - return substring -{substring_index} lval.item = string(l.val) - return substringIndex -{sum} lval.item = string(l.val) - return sum -{sysdate} lval.item = string(l.val) - return sysDate -{table} return tableKwd -{tables} lval.item = string(l.val) - return tables -{then} return then -{to} return to -{trailing} return trailing -{transaction} lval.item = string(l.val) - return transaction -{trim} lval.item = string(l.val) - return trim -{truncate} lval.item = string(l.val) - return truncate -{union} return union -{unique} return unique -{unknown} lval.item = string(l.val) - return unknown -{nullif} lval.item = string(l.val) - return nullIf -{update} return update -{upper} lval.item = string(l.val) - return upper -{use} return use -{user} lval.item = string(l.val) - return user -{using} return using -{value} lval.item = string(l.val) - return value -{values} return values -{variables} lval.item = string(l.val) - return variables -{warnings} lval.item = string(l.val) - return warnings -{week} lval.item = string(l.val) - return week -{weekday} lval.item = string(l.val) - return weekday -{weekofyear} lval.item = string(l.val) - return weekofyear -{when} return when -{where} return where -{xor} return xor -{yearweek} lval.item = string(l.val) - return yearweek -{year_month} lval.item = string(l.val) - return yearMonth - -{signed} lval.item = string(l.val) - return signed -{unsigned} return unsigned -{zerofill} return zerofill - -{null} lval.item = nil - return null - -{false} return falseKwd - -{true} return trueKwd - -{calc_found_rows} lval.item = string(l.val) - return calcFoundRows - -{current_ts} lval.item = string(l.val) - return currentTs -{localtime} return localTime -{localts} return localTs -{now} lval.item = string(l.val) - return now - -{bit} lval.item = string(l.val) - return bitType - -{tiny} lval.item = string(l.val) - return tinyIntType - -{tinyint} lval.item = string(l.val) - return tinyIntType - -{smallint} lval.item = string(l.val) - return smallIntType - -{mediumint} lval.item = string(l.val) - return mediumIntType - -{bigint} lval.item = string(l.val) - return bigIntType - -{decimal} lval.item = string(l.val) - return decimalType - -{numeric} lval.item = string(l.val) - return numericType - -{float} lval.item = string(l.val) - return floatType - -{double} lval.item = string(l.val) - return doubleType - -{precision} lval.item = string(l.val) - return precisionType - -{real} lval.item = string(l.val) - return realType - -{date} lval.item = string(l.val) - return dateType - -{time} lval.item = string(l.val) - return timeType - -{timestamp} lval.item = string(l.val) - return timestampType - -{datetime} lval.item = string(l.val) - return datetimeType - -{year} lval.item = string(l.val) - return yearType - -{char} lval.item = string(l.val) - return charType - -{varchar} lval.item = string(l.val) - return varcharType - -{binary} lval.item = string(l.val) - return binaryType - -{varbinary} lval.item = string(l.val) - return varbinaryType - -{tinyblob} lval.item = string(l.val) - return tinyblobType - -{blob} lval.item = string(l.val) - return blobType - -{mediumblob} lval.item = string(l.val) - return mediumblobType - -{longblob} lval.item = string(l.val) - return longblobType - -{tinytext} lval.item = string(l.val) - return tinytextType - -{mediumtext} lval.item = string(l.val) - return mediumtextType - -{text} lval.item = string(l.val) - return textType - -{longtext} lval.item = string(l.val) - return longtextType - -{bool} lval.item = string(l.val) - return boolType - -{boolean} lval.item = string(l.val) - return booleanType - -{byte} lval.item = string(l.val) - return byteType - -{int} lval.item = string(l.val) - return intType - -{integer} lval.item = string(l.val) - return integerType - -{ident} lval.item = string(l.val) - return l.handleIdent(lval) - -. return c0 - -%% - return int(unicode.ReplacementChar) -} - -func (l *lexer) npos() (line, col int) { - if line, col = l.nline, l.ncol; col == 0 { - line-- - col = l.lcol+1 - } - return -} - -func (l *lexer) str(lval *yySymType, pref string) int { - l.sc = 0 - // TODO: performance issue. - s := string(l.stringLit) - l.stringLit = l.stringLit[0:0] - if pref == "'" { - s = strings.Replace(s, "\\'", "'", -1) - s = strings.TrimSuffix(s, "'") + "\"" - pref = "\"" - } - v := stringutil.RemoveUselessBackslash(pref+s) - v, err := strconv.Unquote(v) - if err != nil { - v = strings.TrimSuffix(s, pref) - } - lval.item = v - return stringLit -} - -func (l *lexer) trimIdent(idt string) string { - idt = strings.TrimPrefix(idt, "`") - idt = strings.TrimSuffix(idt, "`") - return idt -} - -func (l *lexer) int(lval *yySymType) int { - n, err := strconv.ParseUint(string(l.val), 0, 64) - if err != nil { - l.errf("integer literal: %v", err) - return int(unicode.ReplacementChar) - } - - switch { - case n < math.MaxInt64: - lval.item = int64(n) - default: - lval.item = uint64(n) - } - return intLit -} - -func (l *lexer) float(lval *yySymType) int { - n, err := strconv.ParseFloat(string(l.val), 64) - if err != nil { - l.errf("float literal: %v", err) - return int(unicode.ReplacementChar) - } - - lval.item = float64(n) - return floatLit -} - -// https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html -func (l *lexer) hex(lval *yySymType) int { - s := string(l.val) - h, err := mysql.ParseHex(s) - if err != nil { - l.errf("hexadecimal literal: %v", err) - return int(unicode.ReplacementChar) - } - lval.item = h - return hexLit -} - -// https://dev.mysql.com/doc/refman/5.7/en/bit-type.html -func (l *lexer) bit(lval *yySymType) int { - s := string(l.val) - b, err := mysql.ParseBit(s, -1) - if err != nil { - l.errf("bit literal: %v", err) - return int(unicode.ReplacementChar) - } - lval.item = b - return bitLit -} - -func (l *lexer) handleIdent(lval *yySymType) int { - s := lval.item.(string) - // A character string literal may have an optional character set introducer and COLLATE clause: - // [_charset_name]'string' [COLLATE collation_name] - // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html - if !strings.HasPrefix(s, "_") { - return identifier - } - cs, _, err := charset.GetCharsetInfo(s[1:]) - if err != nil { - return identifier - } - lval.item = cs - return underscoreCS -} diff --git a/ast/parser/yy_parser.go b/ast/parser/yy_parser.go deleted file mode 100644 index dfc3264796..0000000000 --- a/ast/parser/yy_parser.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package parser - -// YYParse is an wrapper of `yyParse` to make it exported. -func YYParse(yylex yyLexer) int { - return yyParse(yylex) -} diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index a1805a0bcd..cf7e0e3542 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -19,13 +19,13 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb" - "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/optimizer" + "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/stmt/stmts" diff --git a/optimizer/infobinder_test.go b/optimizer/infobinder_test.go index 988d6dfaa0..b14eadb87d 100644 --- a/optimizer/infobinder_test.go +++ b/optimizer/infobinder_test.go @@ -19,10 +19,10 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb" "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/optimizer" + "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/testkit" ) diff --git a/parser/parser.y b/parser/parser.y index d918b17ee1..1393a7dff1 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -26,22 +26,13 @@ package parser import ( - "fmt" "strings" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/parser/coldef" - "github.com/pingcap/tidb/ddl" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/subquery" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/parser/opcode" - "github.com/pingcap/tidb/rset/rsets" - "github.com/pingcap/tidb/stmt" - "github.com/pingcap/tidb/stmt/stmts" - "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) @@ -347,8 +338,8 @@ import ( %type AlterTableStmt "Alter table statement" - AlterSpecification "Alter table specification" - AlterSpecificationList "Alter table specification list" + AlterTableSpec "Alter table specification" + AlterTableSpecList "Alter table specification list" AnyOrAll "Any or All for subquery" Assignment "assignment" AssignmentList "assignment list" @@ -357,10 +348,7 @@ import ( AuthString "Password string value" BeginTransactionStmt "BEGIN TRANSACTION statement" CastType "Cast function target type" - CharsetName "Charset Name" - CollationName "Collation Name" ColumnDef "table column definition" - ColumnListOpt "Optional column list" ColumnName "column name" ColumnNameList "column name list" ColumnNameListOpt "column name list opt" @@ -370,17 +358,18 @@ import ( CommaOpt "optional comma" CommitStmt "COMMIT statement" CompareOp "Compare opcode" - Constraint "column value constraint" - ConstraintElem "table define constraint element" + ColumnOption "column definition option" + ColumnOptionList "column definition option list" + ColumnOptionListOpt "optional column definition option list" + Constraint "table constraint" + ConstraintElem "table constraint element" ConstraintKeywordOpt "Constraint Keyword or empty" - ConstraintOpt "optional column value constraint" - ConstraintOpts "optional column value constraints" CreateDatabaseStmt "Create Database Statement" CreateIndexStmt "CREATE INDEX statement" CreateIndexStmtUnique "CREATE INDEX optional UNIQUE clause" - CreateSpecification "CREATE Database specification" - CreateSpecificationList "CREATE Database specification list" - CreateSpecListOpt "CREATE Database specification list opt" + DatabaseOption "CREATE Database specification" + DatabaseOptionList "CREATE Database specification list" + DatabaseOptionListOpt "CREATE Database specification list opt" CreateTableStmt "CREATE TABLE statement" CreateUserStmt "CREATE User statement" CrossOpt "Cross join option" @@ -414,7 +403,7 @@ import ( FieldAsName "Field alias name" FieldAsNameOpt "Field alias name opt" FieldList "field expression list" - FromClause "From clause" + TableRefsClause "Table references clause" Function "function expr" FunctionCallAgg "Function call on aggregate data" FunctionCallConflict "Function call with reserved keyword as function name" @@ -425,7 +414,6 @@ import ( GlobalScope "The scope of variable" GrantStmt "Grant statement" GroupByClause "GROUP BY clause" - GroupByList "GROUP BY list" HashString "Hashed string" HavingClause "HAVING clause" IfExists "If Exists" @@ -459,11 +447,10 @@ import ( OptInteger "Optional Integer keyword" Order "ORDER BY clause optional collation specification" OrderBy "ORDER BY clause" - OrderByItem "ORDER BY list item" + ByItem "BY item" OrderByOptional "Optional ORDER BY clause optional" - OrderByList "ORDER BY list" + ByList "BY list" OuterOpt "optional OUTER clause" - QualifiedIdent "qualified identifier" QuickOptional "QUICK or empty" PasswordOpt "Password option" ColumnPosition "Column position [First|After ColumnName]" @@ -489,45 +476,43 @@ import ( SelectStmtLimit "SELECT statement optional LIMIT clause" SelectStmtOpts "Select statement options" SelectStmtGroup "SELECT statement optional GROUP BY clause" - SelectStmtOrder "SELECT statement optional ORDER BY clause" SetStmt "Set variable statement" ShowStmt "Show engines/databases/tables/columns/warnings statement" ShowDatabaseNameOpt "Show tables/columns statement database name option" + ShowTableAliasOpt "Show table alias option" ShowLikeOrWhereOpt "Show like or where condition option" - ShowTableIdentOpt "Show columns statement table name option" SignedLiteral "Literal or NumLiteral with sign" - SimpleQualifiedIdent "Qualified identifier without *" Statement "statement" StatementList "statement list" + StringName "string literal or identifier" StringList "string list" ExplainableStmt "explainable statement" SubSelect "Sub Select" Symbol "Constraint Symbol" SystemVariable "System defined variable name" - TableAsOpt "table as option" - TableConstraint "table constraint definition" + TableAsName "table alias name" + TableAsNameOpt "table alias name optional" TableElement "table definition element" TableElementList "table definition element list" - TableElementListOpt "table definition element list opt" TableFactor "table factor" - TableIdent "Table identifier" - TableIdentList "Table identifier list" - TableIdentOpt "Table identifier option" - TableOpt "create table option" - TableOptList "create table option list" - TableOptListOpt "create table option list opt" + TableName "Table name" + TableNameList "Table name list" + TableOption "create table option" + TableOptionList "create table option list" + TableOptionListOpt "create table option list opt" TableRef "table reference" TableRefs "table references" TimeUnit "Time unit" TrimDirection "Trim string direction" TruncateTableStmt "TRANSACTION TABLE statement" UnionOpt "Union Option(empty/ALL/DISTINCT)" - UnionSelect "Union select/(select)" - UnionStmt "Union statement" + UnionStmt "Union select state ment" + UnionClauseList "Union select clause list" + UnionSelect "Union (select) item" UpdateStmt "UPDATE statement" Username "Username" - UserSpecification "Username and auth option" - UserSpecificationList "Username and auth option list" + UserSpec "Username and auth option" + UserSpecList "Username and auth option list" UserVariable "User defined variable name" UserVariableList "User defined variable name list" UseStmt "USE statement" @@ -574,8 +559,6 @@ import ( NUM "numbers" LengthNum "Field length num(uint64)" - - %token tableRefPriority %precedence lowerThanCalcFoundRows @@ -624,70 +607,67 @@ Start: StatementList | parseExpression Expression { - yylex.(*lexer).expr = expression.Expr($2) + yylex.(*lexer).expr = $2.(ast.ExprNode) } /**************************************AlterTableStmt*************************************** * See: https://dev.mysql.com/doc/refman/5.7/en/alter-table.html *******************************************************************************************/ AlterTableStmt: - "ALTER" IgnoreOptional "TABLE" TableIdent AlterSpecificationList + "ALTER" IgnoreOptional "TABLE" TableName AlterTableSpecList { - $$ = &stmts.AlterTableStmt{ - Ident: $4.(table.Ident), - Specs: $5.([]*ddl.AlterSpecification), + $$ = &ast.AlterTableStmt{ + Table: $4.(*ast.TableName), + Specs: $5.([]*ast.AlterTableSpec), } } -AlterSpecification: - TableOptListOpt +AlterTableSpec: + TableOptionListOpt { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterTableOpt, - TableOpts: $1.([]*coldef.TableOpt), + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableOption, + Options:$1.([]*ast.TableOption), } } | "ADD" ColumnKeywordOpt ColumnDef ColumnPosition { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterAddColumn, - Column: $3.(*coldef.ColumnDef), - Position: $4.(*ddl.ColumnPosition), + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableAddColumn, + Column: $3.(*ast.ColumnDef), + Position: $4.(*ast.ColumnPosition), } } -| "ADD" ConstraintKeywordOpt ConstraintElem +| "ADD" Constraint { - constraint := $3.(*coldef.TableConstraint) - if $2 != nil { - constraint.ConstrName = $2.(string) - } - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterAddConstr, + constraint := $2.(*ast.Constraint) + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableAddConstraint, Constraint: constraint, } } | "DROP" ColumnKeywordOpt ColumnName { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterDropColumn, - Name: $3.(string), + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableDropColumn, + DropColumn: $3.(*ast.ColumnName), } } | "DROP" "PRIMARY" "KEY" { - $$ = &ddl.AlterSpecification{Action: ddl.AlterDropPrimaryKey} + $$ = &ast.AlterTableSpec{Tp: ast.AlterTableDropPrimaryKey} } | "DROP" KeyOrIndex IndexName { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterDropIndex, + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableDropIndex, Name: $3.(string), } } | "DROP" "FOREIGN" "KEY" Symbol { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterDropForeignKey, + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableDropForeignKey, Name: $4.(string), } } @@ -701,28 +681,28 @@ ColumnKeywordOpt: ColumnPosition: { - $$ = &ddl.ColumnPosition{Type: ddl.ColumnPositionNone} + $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionNone} } | "FIRST" { - $$ = &ddl.ColumnPosition{Type: ddl.ColumnPositionFirst} + $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionFirst} } | "AFTER" ColumnName { - $$ = &ddl.ColumnPosition{ - Type: ddl.ColumnPositionAfter, - RelativeColumn: $2.(string), + $$ = &ast.ColumnPosition{ + Tp: ast.ColumnPositionAfter, + RelativeColumn: $2.(*ast.ColumnName), } } -AlterSpecificationList: - AlterSpecification +AlterTableSpecList: + AlterTableSpec { - $$ = []*ddl.AlterSpecification{$1.(*ddl.AlterSpecification)} + $$ = []*ast.AlterTableSpec{$1.(*ast.AlterTableSpec)} } -| AlterSpecificationList ',' AlterSpecification +| AlterTableSpecList ',' AlterTableSpec { - $$ = append($1.([]*ddl.AlterSpecification), $3.(*ddl.AlterSpecification)) + $$ = append($1.([]*ast.AlterTableSpec), $3.(*ast.AlterTableSpec)) } ConstraintKeywordOpt: @@ -743,183 +723,210 @@ Symbol: /*******************************************************************************************/ Assignment: - SimpleQualifiedIdent eq Expression + ColumnName eq Expression { - x, err := expression.NewAssignment($1.(string), $3.(expression.Expression)) - if err != nil { - yylex.(*lexer).errf("Parse Assignment error: %s", $1.(string)) - return 1 - } - $$ = x + $$ = &ast.Assignment{Column: $1.(*ast.ColumnName), Expr:$3.(ast.ExprNode)} } AssignmentList: Assignment { - $$ = []*expression.Assignment{$1.(*expression.Assignment)} + $$ = []*ast.Assignment{$1.(*ast.Assignment)} } | AssignmentList ',' Assignment { - $$ = append($1.([]*expression.Assignment), $3.(*expression.Assignment)) + $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) } AssignmentListOpt: /* EMPTY */ { - $$ = []*expression.Assignment{} + $$ = []*ast.Assignment{} } | AssignmentList BeginTransactionStmt: "BEGIN" { - $$ = &stmts.BeginStmt{} + $$ = &ast.BeginStmt{} } | "START" "TRANSACTION" { - $$ = &stmts.BeginStmt{} + $$ = &ast.BeginStmt{} } ColumnDef: - ColumnName Type ConstraintOpts + ColumnName Type ColumnOptionListOpt { - $$ = &coldef.ColumnDef{Name: $1.(string), Tp: $2.(*types.FieldType), Constraints: $3.([]*coldef.ConstraintOpt)} + $$ = &ast.ColumnDef{Name: $1.(*ast.ColumnName), Tp: $2.(*types.FieldType), Options: $3.([]*ast.ColumnOption)} } ColumnName: Identifier + { + $$ = &ast.ColumnName{Name: model.NewCIStr($1.(string))} + } +| Identifier '.' Identifier + { + $$ = &ast.ColumnName{Table: model.NewCIStr($1.(string)), Name: model.NewCIStr($3.(string))} + } +| Identifier '.' Identifier '.' Identifier + { + $$ = &ast.ColumnName{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string)), Name: model.NewCIStr($5.(string))} + } ColumnNameList: ColumnName { - $$ = []string{$1.(string)} + $$ = []*ast.ColumnName{$1.(*ast.ColumnName)} } | ColumnNameList ',' ColumnName { - $$ = append($1.([]string), $3.(string)) + $$ = append($1.([]*ast.ColumnName), $3.(*ast.ColumnName)) } ColumnNameListOpt: /* EMPTY */ { - $$ = []string{} + $$ = []*ast.ColumnName{} } | ColumnNameList + { + $$ = $1.([]*ast.ColumnName) + } CommitStmt: "COMMIT" { - $$ = &stmts.CommitStmt{} + $$ = &ast.CommitStmt{} } -Constraint: +ColumnOption: "NOT" "NULL" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrNotNull, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNotNull} } | "NULL" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrNull, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNull} } | "AUTO_INCREMENT" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrAutoIncrement, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionAutoIncrement} } | "PRIMARY" "KEY" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrPrimaryKey, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionPrimaryKey} } | "UNIQUE" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrUniq, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniq} } | "UNIQUE" "KEY" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrUniqKey, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniqKey} } | "DEFAULT" DefaultValueExpr { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrDefaultValue, Evalue: $2.(expression.Expression)} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionDefaultValue, Expr: $2.(ast.ExprNode)} } | "ON" "UPDATE" NowSym { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrOnUpdate, Evalue: $3.(expression.Expression)} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionOnUpdate, Expr: $3.(ast.ExprNode)} } | "COMMENT" stringLit { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrComment} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionComment} } | "CHECK" '(' Expression ')' { // See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html // The CHECK clause is parsed but ignored by all storage engines. - $$ = nil + $$ = &ast.ColumnOption{} + } + +ColumnOptionList: + ColumnOption + { + $$ = []*ast.ColumnOption{$1.(*ast.ColumnOption)} + } +| ColumnOptionList ColumnOption + { + $$ = append($1.([]*ast.ColumnOption), $2.(*ast.ColumnOption)) + } + +ColumnOptionListOpt: + { + $$ = []*ast.ColumnOption{} + } +| ColumnOptionList + { + $$ = $1.([]*ast.ColumnOption) } ConstraintElem: "PRIMARY" "KEY" '(' IndexColNameList ')' { - ce := &coldef.TableConstraint{} - ce.Tp = coldef.ConstrPrimaryKey - ce.Keys = $4.([]*coldef.IndexColName) - $$ = ce + $$ = &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: $4.([]*ast.IndexColName)} } | "FULLTEXT" "KEY" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrFulltext, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintFulltext, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string), + } } | "INDEX" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrIndex, - Keys: $4.([]*coldef.IndexColName), - ConstrName: $2.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintIndex, + Keys: $4.([]*ast.IndexColName), + Name: $2.(string), + } } | "KEY" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrKey, - Keys: $4.([]*coldef.IndexColName), - ConstrName: $2.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintKey, + Keys: $4.([]*ast.IndexColName), + Name: $2.(string)} } | "UNIQUE" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrUniq, - Keys: $4.([]*coldef.IndexColName), - ConstrName: $2.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintUniq, + Keys: $4.([]*ast.IndexColName), + Name: $2.(string)} } | "UNIQUE" "INDEX" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrUniqIndex, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintUniqIndex, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string)} } | "UNIQUE" "KEY" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrUniqKey, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintUniqKey, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string)} } | "FOREIGN" "KEY" IndexName '(' IndexColNameList ')' ReferDef { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrForeignKey, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string), - Refer: $7.(*coldef.ReferenceDef), - } + $$ = &ast.Constraint{ + Tp: ast.ConstraintForeignKey, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string), + Refer: $7.(*ast.ReferenceDef), + } } ReferDef: - "REFERENCES" TableIdent '(' IndexColNameList ')' + "REFERENCES" TableName '(' IndexColNameList ')' { - $$ = &coldef.ReferenceDef{TableIdent: $2.(table.Ident), IndexColNames: $4.([]*coldef.IndexColName)} + $$ = &ast.ReferenceDef{Table: $2.(*ast.TableName), IndexColNames: $4.([]*ast.IndexColName)} } /* @@ -941,7 +948,7 @@ DefaultValueExpr: NowSym: "CURRENT_TIMESTAMP" { - $$ = &expression.Ident{CIStr: model.NewCIStr("CURRENT_TIMESTAMP")} + $$ = &ast.IdentifierExpr{Name: model.NewCIStr("CURRENT_TIMESTAMP")} } | "LOCALTIME" | "LOCALTIMESTAMP" @@ -950,62 +957,31 @@ NowSym: SignedLiteral: Literal { - $$ = expression.Value{Val: $1} + $$ = ast.NewValueExpr($1) } | '+' NumLiteral { - n := expression.Value{Val: $2} - $$ = expression.NewUnaryOperation(opcode.Plus, n) + $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: ast.NewValueExpr($2)} } | '-' NumLiteral { - n := expression.Value{Val: $2} - $$ = expression.NewUnaryOperation(opcode.Minus, n) + $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr($2)} } // TODO: support decimal literal NumLiteral: intLit | floatLit - -ConstraintOpt: - Constraint - -ConstraintOpts: - { - $$ = []*coldef.ConstraintOpt{} - } -| ConstraintOpts ConstraintOpt - { - if $2 != nil { - $$ = append($1.([]*coldef.ConstraintOpt), $2.(*coldef.ConstraintOpt)) - } else { - $$ = $1 - } - } - CreateIndexStmt: - "CREATE" CreateIndexStmtUnique "INDEX" Identifier "ON" TableIdent '(' IndexColNameList ')' + "CREATE" CreateIndexStmtUnique "INDEX" Identifier "ON" TableName '(' IndexColNameList ')' { - indexName, tableIdent, colNameList := $4.(string), $6.(table.Ident), $8.([]*coldef.IndexColName) - if strings.EqualFold(indexName, tableIdent.Name.O) { - yylex.(*lexer).errf("index name collision: %s", indexName) - return 1 - } - for _, colName := range colNameList { - if indexName == colName.ColumnName { - yylex.(*lexer).errf("index name collision: %s", indexName) - return 1 - } - } - - $$ = &stmts.CreateIndexStmt{ - Unique: $2.(bool), - IndexName: indexName, - TableIdent: tableIdent, - IndexColNames: colNameList, + $$ = &ast.CreateIndexStmt{ + Unique: $2.(bool), + IndexName: $4.(string), + Table: $6.(*ast.TableName), + IndexColNames: $8.([]*ast.IndexColName), } if yylex.(*lexer).root { break @@ -1025,20 +1001,20 @@ IndexColName: ColumnName OptFieldLen Order { //Order is parsed but just ignored as MySQL did - $$ = &coldef.IndexColName{ColumnName: $1.(string), Length: $2.(int)} + $$ = &ast.IndexColName{Column: $1.(*ast.ColumnName), Length: $2.(int)} } IndexColNameList: { - $$ = []*coldef.IndexColName{} + $$ = []*ast.IndexColName{} } | IndexColName { - $$ = []*coldef.IndexColName{$1.(*coldef.IndexColName)} + $$ = []*ast.IndexColName{$1.(*ast.IndexColName)} } | IndexColNameList ',' IndexColName { - $$ = append($1.([]*coldef.IndexColName), $3.(*coldef.IndexColName)) + $$ = append($1.([]*ast.IndexColName), $3.(*ast.IndexColName)) } @@ -1054,30 +1030,14 @@ IndexColNameList: * | [DEFAULT] COLLATE [=] collation_name *******************************************************************/ CreateDatabaseStmt: - "CREATE" DatabaseSym IfNotExists DBName CreateSpecListOpt + "CREATE" DatabaseSym IfNotExists DBName DatabaseOptionListOpt { - opts := $5.([]*coldef.DatabaseOpt) - //compose charset from x - var cs, co string - for _, x := range opts { - switch x.Tp { - case coldef.DBOptCharset, coldef.DBOptCollate: - cs = x.Value - } + $$ = &ast.CreateDatabaseStmt{ + IfNotExists: $3.(bool), + Name: $4.(string), + Options: $5.([]*ast.DatabaseOption), } - ok := charset.ValidCharsetAndCollation(cs, co) - if !ok { - yylex.(*lexer).errf("Unknown character set %s or collate %s ", cs, co) - return 1 - } - dbopt := &coldef.CharsetOpt{Chs: cs, Col: co} - - $$ = &stmts.CreateDatabaseStmt{ - IfNotExists: $3.(bool), - Name: $4.(string), - Opt: dbopt} - if yylex.(*lexer).root { break } @@ -1086,55 +1046,30 @@ CreateDatabaseStmt: DBName: Identifier -CreateSpecification: - DefaultKwdOpt CharsetKw EqOpt CharsetName +DatabaseOption: + DefaultKwdOpt CharsetKw EqOpt StringName { - $$ = &coldef.DatabaseOpt{Tp: coldef.DBOptCollate, Value: $4.(string)} + $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCharset, Value: $4.(string)} } -| DefaultKwdOpt "COLLATE" EqOpt CollationName +| DefaultKwdOpt "COLLATE" EqOpt StringName { - $$ = &coldef.DatabaseOpt{Tp: coldef.DBOptCollate, Value: $4.(string)} + $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: $4.(string)} } -CharsetName: - Identifier +DatabaseOptionListOpt: { - c := strings.ToLower($1.(string)) - if charset.ValidCharsetAndCollation(c, "") { - $$ = c - } else { - yylex.(*lexer).errf("Unknown character set: '%s'", $1.(string)) - return 1 - } - } -| stringLit - { - c := strings.ToLower($1.(string)) - if charset.ValidCharsetAndCollation(c, "") { - $$ = c - } else { - yylex.(*lexer).errf("Unknown character set: '%s'", $1.(string)) - return 1 - } + $$ = []*ast.DatabaseOption{} } +| DatabaseOptionList -CollationName: - Identifier - -CreateSpecListOpt: +DatabaseOptionList: + DatabaseOption { - $$ = []*coldef.DatabaseOpt{} + $$ = []*ast.DatabaseOption{$1.(*ast.DatabaseOption)} } -| CreateSpecificationList - -CreateSpecificationList: - CreateSpecification +| DatabaseOptionList DatabaseOption { - $$ = []*coldef.DatabaseOpt{$1.(*coldef.DatabaseOpt)} - } -| CreateSpecificationList CreateSpecification - { - $$ = append($1.([]*coldef.DatabaseOpt), $2.(*coldef.DatabaseOpt)) + $$ = append($1.([]*ast.DatabaseOption), $2.(*ast.DatabaseOption)) } /******************************************************************* @@ -1153,45 +1088,30 @@ CreateSpecificationList: * ) *******************************************************************/ CreateTableStmt: - "CREATE" "TABLE" IfNotExists TableIdent '(' TableElementListOpt ')' TableOptListOpt + "CREATE" "TABLE" IfNotExists TableName '(' TableElementList ')' TableOptionListOpt { tes := $6.([]interface {}) - var columnDefs []*coldef.ColumnDef - var tableConstraints []*coldef.TableConstraint + var columnDefs []*ast.ColumnDef + var constraints []*ast.Constraint for _, te := range tes { switch te := te.(type) { - case *coldef.ColumnDef: + case *ast.ColumnDef: columnDefs = append(columnDefs, te) - case *coldef.TableConstraint: - tableConstraints = append(tableConstraints, te) + case *ast.Constraint: + constraints = append(constraints, te) } } if len(columnDefs) == 0 { yylex.(*lexer).err("Column Definition List can't be empty.") return 1 } - - opt := &coldef.TableOption{} - if $8 != nil { - for _, o := range $8.([]*coldef.TableOpt) { - switch o.Tp { - case coldef.TblOptEngine: - opt.Engine = o.StrValue - case coldef.TblOptCharset: - opt.Charset = o.StrValue - case coldef.TblOptCollate: - opt.Collate = o.StrValue - case coldef.TblOptAutoIncrement: - opt.AutoIncrement = o.UintValue - } - } - } - $$ = &stmts.CreateTableStmt{ - Ident: $4.(table.Ident), + $$ = &ast.CreateTableStmt{ + Table: $4.(*ast.TableName), IfNotExists: $3.(bool), Cols: columnDefs, - Constraints: tableConstraints, - Opt: opt} + Constraints: constraints, + Options: $8.([]*ast.TableOption), + } } Default: @@ -1217,8 +1137,8 @@ DefaultKwdOpt: DoStmt: "DO" ExpressionList { - $$ = &stmts.DoStmt { - Exprs: $2.([]expression.Expression), + $$ = &ast.DoStmt { + Exprs: $2.([]ast.ExprNode), } } @@ -1228,27 +1148,24 @@ DoStmt: * *******************************************************************/ DeleteFromStmt: - "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableIdent WhereClauseOptional OrderByOptional LimitClause + "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableName WhereClauseOptional OrderByOptional LimitClause { // Single Table - ts := &rsets.TableSource{Source: $6} - r := &rsets.JoinRset{Left: ts, Right: nil} - x := &stmts.DeleteStmt{ - Refs: r, + join := &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil} + x := &ast.DeleteStmt{ + TableRefs: &ast.TableRefsClause{TableRefs: join}, LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), } if $7 != nil { - x.Where = $7.(expression.Expression) + x.Where = $7.(ast.ExprNode) } - if $8 != nil { - x.Order = $8.(*rsets.OrderByRset) + x.Order = $8.(*ast.OrderByClause) } - if $9 != nil { - x.Limit = $9.(*rsets.LimitRset) + x.Limit = $9.(*ast.Limit) } $$ = x @@ -1256,39 +1173,39 @@ DeleteFromStmt: break } } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional TableIdentList "FROM" TableRefs WhereClauseOptional +| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional TableNameList "FROM" TableRefs WhereClauseOptional { // Multiple Table - x := &stmts.DeleteStmt{ + x := &ast.DeleteStmt{ LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), MultiTable: true, BeforeFrom: true, - TableIdents: $5.([]table.Ident), - Refs: $7.(*rsets.JoinRset), + Tables: $5.([]*ast.TableName), + TableRefs: &ast.TableRefsClause{TableRefs: $7.(*ast.Join)}, } if $8 != nil { - x.Where = $8.(expression.Expression) + x.Where = $8.(ast.ExprNode) } $$ = x if yylex.(*lexer).root { break } } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableIdentList "USING" TableRefs WhereClauseOptional +| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableNameList "USING" TableRefs WhereClauseOptional { // Multiple Table - x := &stmts.DeleteStmt{ + x := &ast.DeleteStmt{ LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), MultiTable: true, - TableIdents: $6.([]table.Ident), - Refs: $8.(*rsets.JoinRset), + Tables: $6.([]*ast.TableName), + TableRefs: &ast.TableRefsClause{TableRefs: $8.(*ast.Join)}, } if $9 != nil { - x.Where = $9.(expression.Expression) + x.Where = $9.(ast.ExprNode) } $$ = x if yylex.(*lexer).root { @@ -1302,29 +1219,29 @@ DatabaseSym: DropDatabaseStmt: "DROP" DatabaseSym IfExists DBName { - $$ = &stmts.DropDatabaseStmt{IfExists: $3.(bool), Name: $4.(string)} + $$ = &ast.DropDatabaseStmt{IfExists: $3.(bool), Name: $4.(string)} if yylex.(*lexer).root { break } } DropIndexStmt: - "DROP" "INDEX" IfExists Identifier + "DROP" "INDEX" IfExists Identifier "ON" TableName { - $$ = &stmts.DropIndexStmt{IfExists: $3.(bool), IndexName: $4.(string)} + $$ = &ast.DropIndexStmt{IfExists: $3.(bool), IndexName: $4.(string), Table: $6.(*ast.TableName)} } DropTableStmt: - "DROP" TableOrTables TableIdentList + "DROP" TableOrTables TableNameList { - $$ = &stmts.DropTableStmt{TableIdents: $3.([]table.Ident)} + $$ = &ast.DropTableStmt{Tables: $3.([]*ast.TableName)} if yylex.(*lexer).root { break } } -| "DROP" "TABLE" "IF" "EXISTS" TableIdentList +| "DROP" TableOrTables "IF" "EXISTS" TableNameList { - $$ = &stmts.DropTableStmt{IfExists: true, TableIdents: $5.([]table.Ident)} + $$ = &ast.DropTableStmt{IfExists: true, Tables: $5.([]*ast.TableName)} if yylex.(*lexer).root { break } @@ -1353,26 +1270,28 @@ ExplainSym: | "DESC" ExplainStmt: - ExplainSym TableIdent + ExplainSym TableName { - $$ = &stmts.ExplainStmt{ - S:&stmts.ShowStmt{ - Target: stmt.ShowColumns, - TableIdent: $2.(table.Ident)}, + $$ = &ast.ExplainStmt{ + Stmt: &ast.ShowStmt{ + Tp: ast.ShowColumns, + Table: $2.(*ast.TableName), + }, } } -| ExplainSym TableIdent ColumnName +| ExplainSym TableName ColumnName { - $$ = &stmts.ExplainStmt{ - S:&stmts.ShowStmt{ - Target: stmt.ShowColumns, - TableIdent: $2.(table.Ident), - ColumnName: $3.(string)}, + $$ = &ast.ExplainStmt{ + Stmt: &ast.ShowStmt{ + Tp: ast.ShowColumns, + Table: $2.(*ast.TableName), + Column: $3.(*ast.ColumnName), + }, } } | ExplainSym ExplainableStmt { - $$ = &stmts.ExplainStmt{S:$2.(stmt.Statement)} + $$ = &ast.ExplainStmt{Stmt: $2.(ast.StmtNode)} } LengthNum: @@ -1392,32 +1311,32 @@ NUM: Expression: Expression logOr Expression %prec oror { - $$ = expression.NewBinaryOperation(opcode.OrOr, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.OrOr, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | Expression "XOR" Expression %prec xor { - $$ = expression.NewBinaryOperation(opcode.LogicXor, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.LogicXor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | Expression logAnd Expression %prec andand { - $$ = expression.NewBinaryOperation(opcode.AndAnd, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.AndAnd, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | "NOT" Expression %prec not { - $$ = expression.NewUnaryOperation(opcode.Not, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} } | Factor "IS" NotOpt trueKwd %prec is { - $$ = &expression.IsTruth{Expr:$1.(expression.Expression), Not: $3.(bool), True: int64(1)} + $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(1)} } | Factor "IS" NotOpt falseKwd %prec is { - $$ = &expression.IsTruth{Expr:$1.(expression.Expression), Not: $3.(bool), True: int64(0)} + $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(0)} } | Factor "IS" NotOpt "UNKNOWN" %prec is { /* https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#operator_is */ - $$ = &expression.IsNull{Expr: $1.(expression.Expression), Not: $3.(bool)} + $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} } | Factor @@ -1444,31 +1363,31 @@ name: ExpressionList: Expression { - $$ = []expression.Expression{expression.Expr($1)} + $$ = []ast.ExprNode{$1.(ast.ExprNode)} } | ExpressionList ',' Expression { - $$ = append($1.([]expression.Expression), expression.Expr($3)) + $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) } ExpressionListOpt: { - $$ = []expression.Expression{} + $$ = []ast.ExprNode{} } | ExpressionList Factor: Factor "IS" NotOpt "NULL" %prec is { - $$ = &expression.IsNull{Expr: $1.(expression.Expression), Not: $3.(bool)} + $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} } | Factor CompareOp PredicateExpr %prec eq { - $$ = expression.NewBinaryOperation($2.(opcode.Op), $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | Factor CompareOp AnyOrAll SubSelect %prec eq { - $$ = expression.NewCompareSubQuery($2.(opcode.Op), $1.(expression.Expression), $4.(*subquery.SubQuery), $3.(bool)) + $$ = &ast.CompareSubqueryExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $4.(*ast.SubqueryExpr), All: $3.(bool)} } | PredicateExpr @@ -1523,19 +1442,19 @@ AnyOrAll: PredicateExpr: PrimaryFactor NotOpt "IN" '(' ExpressionList ')' { - $$ = &expression.PatternIn{Expr: $1.(expression.Expression), Not: $2.(bool), List: $5.([]expression.Expression)} + $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), List: $5.([]ast.ExprNode)} } | PrimaryFactor NotOpt "IN" SubSelect { - $$ = &expression.PatternIn{Expr: $1.(expression.Expression), Not: $2.(bool), Sel: $4.(*subquery.SubQuery)} + $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), Sel: $4.(*ast.SubqueryExpr)} } | PrimaryFactor NotOpt "BETWEEN" PrimaryFactor "AND" PredicateExpr { - var err error - $$, err = expression.NewBetween($1.(expression.Expression), $4.(expression.Expression), $6.(expression.Expression), $2.(bool)) - if err != nil { - yylex.(*lexer).err(err) - return 1 + $$ = &ast.BetweenExpr{ + Expr: $1.(ast.ExprNode), + Left: $4.(ast.ExprNode), + Right: $6.(ast.ExprNode), + Not: $2.(bool), } } | PrimaryFactor NotOpt "LIKE" PrimaryExpression LikeEscapeOpt @@ -1547,15 +1466,16 @@ PredicateExpr: } else if len(escape) == 0 { escape = "\\" } - $$ = &expression.PatternLike{ - Expr: $1.(expression.Expression), - Pattern: $4.(expression.Expression), - Not: $2.(bool), - Escape: escape[0]} + $$ = &ast.PatternLikeExpr{ + Expr: $1.(ast.ExprNode), + Pattern: $4.(ast.ExprNode), + Not: $2.(bool), + Escape: escape[0], + } } | PrimaryFactor NotOpt RegexpSym PrimaryExpression { - $$ = &expression.PatternRegexp{Expr: $1.(expression.Expression), Pattern: $4.(expression.Expression), Not: $2.(bool)} + $$ = &ast.PatternRegexpExpr{Expr: $1.(ast.ExprNode), Pattern: $4.(ast.ExprNode), Not: $2.(bool)} } | PrimaryFactor @@ -1585,12 +1505,29 @@ NotOpt: Field: '*' { - $$ = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr("*")}} + $$ = &ast.SelectField{WildCard: &ast.WildCardField{}} + } +| Identifier '.' '*' + { + wildCard := &ast.WildCardField{Table: model.NewCIStr($1.(string))} + $$ = &ast.SelectField{WildCard: wildCard} + } +| Identifier '.' Identifier '.' '*' + { + wildCard := &ast.WildCardField{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Expression FieldAsNameOpt { - expr, name := expression.Expr($1), $2.(string) - $$ = &field.Field{Expr: expr, AsName: name} + expr := $1.(ast.ExprNode) + asName := $2.(string) + if asName != "" { + // Set expr original text. + offset := yyS[yypt-1].offset + end := yyS[yypt].offset-1 + expr.SetText(yylex.(*lexer).src[offset:end]) + } + $$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)} } FieldAsNameOpt: @@ -1624,27 +1561,28 @@ FieldAsName: FieldList: Field { - $$ = []*field.Field{$1.(*field.Field)} + field := $1.(*ast.SelectField) + field.Offset = yyS[yypt].offset + $$ = []*ast.SelectField{field} } | FieldList ',' Field { - $$ = append($1.([]*field.Field), $3.(*field.Field)) + + fl := $1.([]*ast.SelectField) + last := fl[len(fl)-1] + if last.Expr != nil && last.AsName.O == "" { + lastEnd := yyS[yypt-1].offset-1 // Comma offset. + last.SetText(yylex.(*lexer).src[last.Offset:lastEnd]) + } + newField := $3.(*ast.SelectField) + newField.Offset = yyS[yypt].offset + $$ = append(fl, newField) } GroupByClause: - "GROUP" "BY" GroupByList + "GROUP" "BY" ByList { - $$ = &rsets.GroupByRset{By: $3.([]expression.Expression)} - } - -GroupByList: - Expression - { - $$ = []expression.Expression{$1.(expression.Expression)} - } -| GroupByList ',' Expression - { - $$ = append($1.([]expression.Expression), $3.(expression.Expression)) + $$ = &ast.GroupByClause{Items: $3.([]*ast.ByItem)} } HavingClause: @@ -1653,7 +1591,7 @@ HavingClause: } | "HAVING" Expression { - $$ = &rsets.HavingRset{Expr:$2.(expression.Expression)} + $$ = &ast.HavingClause{Expr: $2.(ast.ExprNode)} } IfExists: @@ -1726,13 +1664,15 @@ NotKeywordToken: * TODO: support PARTITION **********************************************************************************/ InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableIdent InsertValues OnDuplicateKeyUpdate + "INSERT" Priority IgnoreOptional IntoOpt TableName InsertValues OnDuplicateKeyUpdate { - x := &stmts.InsertIntoStmt{InsertValues: $6.(stmts.InsertValues)} + x := $6.(*ast.InsertStmt) x.Priority = $2.(int) - x.TableIdent = $5.(table.Ident) + // Wraps many layers here so that it can be processed the same way as select statement. + ts := &ast.TableSource{Source: $5.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} if $7 != nil { - x.OnDuplicate = $7.([]*expression.Assignment) + x.OnDuplicate = $7.([]*ast.Assignment) } $$ = x if yylex.(*lexer).root { @@ -1750,33 +1690,34 @@ IntoOpt: InsertValues: '(' ColumnNameListOpt ')' ValueSym ExpressionListList { - $$ = stmts.InsertValues{ - ColNames: $2.([]string), - Lists: $5.([][]expression.Expression)} + $$ = &ast.InsertStmt{ + Columns: $2.([]*ast.ColumnName), + Lists: $5.([][]ast.ExprNode), + } } | '(' ColumnNameListOpt ')' SelectStmt { - $$ = stmts.InsertValues{ColNames: $2.([]string), Sel: $4.(*stmts.SelectStmt)} + $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.SelectStmt)} } | '(' ColumnNameListOpt ')' UnionStmt { - $$ = stmts.InsertValues{ColNames: $2.([]string), Sel: $4.(*stmts.UnionStmt)} + $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.UnionStmt)} } | ValueSym ExpressionListList %prec insertValues { - $$ = stmts.InsertValues{Lists: $2.([][]expression.Expression)} + $$ = &ast.InsertStmt{Lists: $2.([][]ast.ExprNode)} } | SelectStmt { - $$ = stmts.InsertValues{Sel: $1.(*stmts.SelectStmt)} + $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} } | UnionStmt { - $$ = stmts.InsertValues{Sel: $1.(*stmts.UnionStmt)} + $$ = &ast.InsertStmt{Select: $1.(*ast.UnionStmt)} } | "SET" ColumnSetValueList { - $$ = stmts.InsertValues{Setlist: $2.([]*expression.Assignment)} + $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} } ValueSym: @@ -1786,40 +1727,41 @@ ValueSym: ExpressionListList: '(' ')' { - $$ = [][]expression.Expression{[]expression.Expression{}} + $$ = [][]ast.ExprNode{[]ast.ExprNode{}} } | '(' ')' ',' ExpressionListList { - $$ = append([][]expression.Expression{[]expression.Expression{}}, $4.([][]expression.Expression)...) + $$ = append([][]ast.ExprNode{[]ast.ExprNode{}}, $4.([][]ast.ExprNode)...) } | '(' ExpressionList ')' { - $$ = [][]expression.Expression{$2.([]expression.Expression)} + $$ = [][]ast.ExprNode{$2.([]ast.ExprNode)} } | '(' ExpressionList ')' ',' ExpressionListList { - $$ = append([][]expression.Expression{$2.([]expression.Expression)}, $5.([][]expression.Expression)...) + $$ = append([][]ast.ExprNode{$2.([]ast.ExprNode)}, $5.([][]ast.ExprNode)...) } ColumnSetValue: ColumnName eq Expression { - $$ = &expression.Assignment{ - ColName: $1.(string), - Expr: expression.Expr($3)} + $$ = &ast.Assignment{ + Column: $1.(*ast.ColumnName), + Expr: $3.(ast.ExprNode), + } } ColumnSetValueList: { - $$ = []*expression.Assignment{} + $$ = []*ast.Assignment{} } | ColumnSetValue { - $$ = []*expression.Assignment{$1.(*expression.Assignment)} + $$ = []*ast.Assignment{$1.(*ast.Assignment)} } | ColumnSetValueList ',' ColumnSetValue { - $$ = append($1.([]*expression.Assignment), $3.(*expression.Assignment)) + $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) } /* @@ -1835,34 +1777,40 @@ OnDuplicateKeyUpdate: $$ = $5 } +/***********************************Insert Statements END************************************/ + /************************************************************************************ - * Replace Statments + * Replace Statements * See: https://dev.mysql.com/doc/refman/5.7/en/replace.html * * TODO: support PARTITION **********************************************************************************/ ReplaceIntoStmt: - "REPLACE" ReplacePriority IntoOpt TableIdent InsertValues + "REPLACE" ReplacePriority IntoOpt TableName InsertValues { - x := &stmts.ReplaceIntoStmt{InsertValues: $5.(stmts.InsertValues)} + x := $5.(*ast.InsertStmt) + x.Replace = true x.Priority = $2.(int) - x.TableIdent = $4.(table.Ident) + ts := &ast.TableSource{Source: $4.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} $$ = x } ReplacePriority: { - $$ = stmts.NoPriority + $$ = ast.NoPriority } | "LOW_PRIORITY" { - $$ = stmts.LowPriority + $$ = ast.LowPriority } | "DELAYED" { - $$ = stmts.DelayedPriority + $$ = ast.DelayedPriority } +/***********************************Replace Statments END************************************/ + Literal: "false" { @@ -1908,23 +1856,23 @@ Literal: Operand: Literal { - $$ = expression.Value{Val: $1} + $$ = ast.NewValueExpr($1) } -| QualifiedIdent +| ColumnName { - $$ = &expression.Ident{CIStr: model.NewCIStr($1.(string))} + $$ = &ast.ColumnNameExpr{Name: $1.(*ast.ColumnName)} } | '(' Expression ')' { - $$ = &expression.PExpr{Expr: expression.Expr($2)} + $$ = &ast.ParenthesesExpr{Expr: $2.(ast.ExprNode)} } | "DEFAULT" %prec lowerThanLeftParen { - $$ = &expression.Default{} + $$ = &ast.DefaultExpr{} } | "DEFAULT" '(' ColumnName ')' { - $$ = &expression.Default{Name: $3.(string)} + $$ = &ast.DefaultExpr{Name: $3.(*ast.ColumnName)} } | Variable { @@ -1932,64 +1880,59 @@ Operand: } | "PLACEHOLDER" { - l := yylex.(*lexer) - if !l.prepare { - l.err("Can not accept placeholder when not parsing prepare sql") - return 1 + $$ = &ast.ParamMarkerExpr{ + Offset: yyS[yypt].offset, } - pm := &expression.ParamMarker{} - l.ParamList = append(l.ParamList, pm) - $$ = pm } | "ROW" '(' Expression ',' ExpressionList ')' { - values := append([]expression.Expression{$3.(expression.Expression)}, $5.([]expression.Expression)...) - $$ = &expression.Row{Values: values} + values := append([]ast.ExprNode{$3.(ast.ExprNode)}, $5.([]ast.ExprNode)...) + $$ = &ast.RowExpr{Values: values} } | '(' Expression ',' ExpressionList ')' { - values := append([]expression.Expression{$2.(expression.Expression)}, $4.([]expression.Expression)...) - $$ = &expression.Row{Values: values} + values := append([]ast.ExprNode{$2.(ast.ExprNode)}, $4.([]ast.ExprNode)...) + $$ = &ast.RowExpr{Values: values} } | "EXISTS" SubSelect { - $$ = &expression.ExistsSubQuery{Sel: $2.(*subquery.SubQuery)} + $$ = &ast.ExistsSubqueryExpr{Sel: $2.(*ast.SubqueryExpr)} } OrderBy: - "ORDER" "BY" OrderByList + "ORDER" "BY" ByList { - $$ = &rsets.OrderByRset{By: $3.([]rsets.OrderByItem)} + $$ = &ast.OrderByClause{Items: $3.([]*ast.ByItem)} } -OrderByList: - OrderByItem +ByList: + ByItem { - $$ = []rsets.OrderByItem{$1.(rsets.OrderByItem)} + $$ = []*ast.ByItem{$1.(*ast.ByItem)} } -| OrderByList ',' OrderByItem +| ByList ',' ByItem { - $$ = append($1.([]rsets.OrderByItem), $3.(rsets.OrderByItem)) + $$ = append($1.([]*ast.ByItem), $3.(*ast.ByItem)) } -OrderByItem: +ByItem: Expression Order { - $$ = rsets.OrderByItem{Expr: $1.(expression.Expression), Asc: $2.(bool)} + $$ = &ast.ByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} } Order: /* EMPTY */ { - $$ = true // ASC by default + $$ = false // ASC by default } | "ASC" { - $$ = true + $$ = false } | "DESC" { - $$ = false + $$ = true } OrderByOptional: @@ -2007,19 +1950,19 @@ PrimaryExpression: | SubSelect | '!' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.Not, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} } | '~' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.BitNeg, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.BitNeg, V: $2.(ast.ExprNode)} } | '-' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.Minus, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: $2.(ast.ExprNode)} } | '+' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.Plus, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: $2.(ast.ExprNode)} } | "BINARY" PrimaryExpression %prec neg { @@ -2027,13 +1970,13 @@ PrimaryExpression: x := types.NewFieldType(mysql.TypeString) x.Charset = charset.CharsetBin x.Collate = charset.CharsetBin - $$ = &expression.FunctionCast{ - Expr: $2.(expression.Expression), + $$ = &ast.FuncCastExpr{ + Expr: $2.(ast.ExprNode), Tp: x, - FunctionType: expression.BinaryOperator, + FunctionType: ast.CastBinaryOperator, } } -| PrimaryExpression "COLLATE" CollationName %prec neg +| PrimaryExpression "COLLATE" StringName %prec neg { // TODO: Create a builtin function hold expr and collation. When do evaluation, convert expr result using the collation. $$ = $1 @@ -2051,34 +1994,16 @@ FunctionNameConflict: FunctionCallConflict: FunctionNameConflict '(' ExpressionListOpt ')' { - x := yylex.(*lexer) - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - x.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURRENT_USER" { // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user - x := yylex.(*lexer) - var err error - $$, err = expression.NewCall($1.(string), []expression.Expression{}, false) - if err != nil { - x.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_DATE" { - x := yylex.(*lexer) - var err error - $$, err = expression.NewCall($1.(string), []expression.Expression{}, false) - if err != nil { - x.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } DistinctOpt: @@ -2099,604 +2024,346 @@ DistinctOpt: } FunctionCallKeyword: - "AVG" '(' DistinctOpt ExpressionList ')' - { - var err error - $$, err = expression.NewCall($1.(string), $4.([]expression.Expression), $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } - } -| "CAST" '(' Expression "AS" CastType ')' + "CAST" '(' Expression "AS" CastType ')' { /* See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_cast */ - $$ = &expression.FunctionCast{ - Expr: $3.(expression.Expression), + $$ = &ast.FuncCastExpr{ + Expr: $3.(ast.ExprNode), Tp: $5.(*types.FieldType), - FunctionType: expression.CastFunction, - } + FunctionType: ast.CastFunction, + } } | "CASE" ExpressionOpt WhenClauseList ElseOpt "END" { - x := &expression.FunctionCase{WhenClauses: $3.([]*expression.WhenClause)} + x := &ast.CaseExpr{WhenClauses: $3.([]*ast.WhenClause)} if $2 != nil { - x.Value = $2.(expression.Expression) + x.Value = $2.(ast.ExprNode) } if $4 != nil { - x.ElseClause = $4.(expression.Expression) + x.ElseClause = $4.(ast.ExprNode) } $$ = x } -| "CONVERT" '(' Expression "USING" CharsetName ')' +| "CONVERT" '(' Expression "USING" StringName ')' { // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &expression.FunctionConvert{ - Expr: $3.(expression.Expression), + $$ = &ast.FuncConvertExpr{ + Expr: $3.(ast.ExprNode), Charset: $5.(string), } } -| "CONVERT" '(' Expression ',' CastType ')' +| "CONVERT" '(' Expression ',' CastType ')' { // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &expression.FunctionCast{ - Expr: $3.(expression.Expression), + $$ = &ast.FuncCastExpr{ + Expr: $3.(ast.ExprNode), Tp: $5.(*types.FieldType), - FunctionType: expression.ConvertFunction, + FunctionType: ast.CastConvertFunction, } } | "DATE" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "USER" '(' ')' { - args := []expression.Expression{} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } -| "VALUES" '(' Identifier ')' %prec lowerThanInsertValues +| "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues { // TODO: support qualified identifier for column_name - $$ = &expression.Values{CIStr: model.NewCIStr($3.(string))} + $$ = &ast.ValuesExpr{Column: $3.(*ast.ColumnName)} } | "WEEK" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "YEAR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } FunctionCallNonKeyword: "COALESCE" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURDATE" '(' ')' { - var err error - $$, err = expression.NewCall($1.(string), []expression.Expression{}, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_TIMESTAMP" FuncDatetimePrec { - args := []expression.Expression{} + args := []ast.ExprNode{} if $2 != nil { - args = append(args, $2.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $2.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "ABS" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "CONCAT" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CONCAT_WS" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "DAY" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFWEEK" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFMONTH" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFYEAR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { - $$ = &expression.DateArith{ - Op:expression.DateAdd, - Unit: $7.(string), - Date: $3.(expression.Expression), - Interval: $6.(expression.Expression), + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateAdd, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), } } | "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { - $$ = &expression.DateArith{ - Op:expression.DateSub, + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateSub, Unit: $7.(string), - Date: $3.(expression.Expression), - Interval: $6.(expression.Expression), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), } } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { - $$ = &expression.Extract{ - Unit: $3.(string), - Date: $5.(expression.Expression), + $$ = &ast.FuncExtractExpr{ + Unit: $3.(string), + Date: $5.(ast.ExprNode), } } | "FOUND_ROWS" '(' ')' { - args := []expression.Expression{} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "HOUR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "IFNULL" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "LENGTH" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "LOCATE" '(' Expression ',' Expression ')' { - $$ = &expression.FunctionLocate{ - SubStr: $3.(expression.Expression), - Str: $5.(expression.Expression), + $$ = &ast.FuncLocateExpr{ + SubStr: $3.(ast.ExprNode), + Str: $5.(ast.ExprNode), } } | "LOCATE" '(' Expression ',' Expression ',' Expression ')' { - $$ = &expression.FunctionLocate{ - SubStr: $3.(expression.Expression), - Str: $5.(expression.Expression), - Pos: $7.(expression.Expression), + $$ = &ast.FuncLocateExpr{ + SubStr: $3.(ast.ExprNode), + Str: $5.(ast.ExprNode), + Pos: $7.(ast.ExprNode), } } | "LOWER" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MICROSECOND" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MINUTE" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MONTH" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "NOW" '(' ExpressionOpt ')' { - args := []expression.Expression{} + args := []ast.ExprNode{} if $3 != nil { - args = append(args, $3.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $3.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "NULLIF" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "RAND" '(' ExpressionOpt ')' { - args := []expression.Expression{} + args := []ast.ExprNode{} if $3 != nil { - args = append(args, $3.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $3.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "REPLACE" '(' Expression ',' Expression ',' Expression ')' { - args := []expression.Expression{$3.(expression.Expression), - $5.(expression.Expression), - $7.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "SECOND" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "SUBSTRING" '(' Expression ',' Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), } } | "SUBSTRING" '(' Expression "FROM" Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), } } | "SUBSTRING" '(' Expression ',' Expression ',' Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), - Len: $7.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), + Len: $7.(ast.ExprNode), } } | "SUBSTRING" '(' Expression "FROM" Expression "FOR" Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), - Len: $7.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), + Len: $7.(ast.ExprNode), } } | "SUBSTRING_INDEX" '(' Expression ',' Expression ',' Expression ')' { - $$ = &expression.FunctionSubstringIndex{ - StrExpr: $3.(expression.Expression), - Delim: $5.(expression.Expression), - Count: $7.(expression.Expression), + $$ = &ast.FuncSubstringIndexExpr{ + StrExpr: $3.(ast.ExprNode), + Delim: $5.(ast.ExprNode), + Count: $7.(ast.ExprNode), } } | "SYSDATE" '(' ExpressionOpt ')' { - args := []expression.Expression{} + args := []ast.ExprNode{} if $3 != nil { - args = append(args, $3.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $3.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "TRIM" '(' Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $3.(expression.Expression), + $$ = &ast.FuncTrimExpr{ + Str: $3.(ast.ExprNode), } } | "TRIM" '(' Expression "FROM" Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $5.(expression.Expression), - RemStr: $3.(expression.Expression), + $$ = &ast.FuncTrimExpr{ + Str: $5.(ast.ExprNode), + RemStr: $3.(ast.ExprNode), } } | "TRIM" '(' TrimDirection "FROM" Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $5.(expression.Expression), - Direction: $3.(int), + $$ = &ast.FuncTrimExpr{ + Str: $5.(ast.ExprNode), + Direction: $3.(ast.TrimDirectionType), } } | "TRIM" '(' TrimDirection Expression "FROM" Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $6.(expression.Expression), - RemStr: $4.(expression.Expression), - Direction: $3.(int), + $$ = &ast.FuncTrimExpr{ + Str: $6.(ast.ExprNode), + RemStr: $4.(ast.ExprNode), + Direction: $3.(ast.TrimDirectionType), } } | "UPPER" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKDAY" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKOFYEAR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "YEARWEEK" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression),false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } TrimDirection: "BOTH" { - $$ = expression.TrimBoth + $$ = ast.TrimBoth } | "LEADING" { - $$ = expression.TrimLeading + $$ = ast.TrimLeading } | "TRAILING" { - $$ = expression.TrimTrailing + $$ = ast.TrimTrailing } FunctionCallAgg: - "COUNT" '(' DistinctOpt ExpressionList ')' + "AVG" '(' DistinctOpt ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $4.([]expression.Expression), $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + } +| "COUNT" '(' DistinctOpt ExpressionList ')' + { + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt '*' ')' { - var err error - args := []expression.Expression{ expression.Value{Val: expression.TypeStar("*")} } - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + args := []ast.ExprNode{ast.NewValueExpr(ast.UnquoteString("*"))} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $4.([]expression.Expression),$3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "MAX" '(' DistinctOpt Expression ')' { - args := []expression.Expression{$4.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "MIN" '(' DistinctOpt Expression ')' { - args := []expression.Expression{$4.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "SUM" '(' DistinctOpt Expression ')' { - args := []expression.Expression{$4.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } FuncDatetimePrec: @@ -2730,19 +2397,19 @@ ExpressionOpt: WhenClauseList: WhenClause { - $$ = []*expression.WhenClause{$1.(*expression.WhenClause)} + $$ = []*ast.WhenClause{$1.(*ast.WhenClause)} } | WhenClauseList WhenClause { - $$ = append($1.([]*expression.WhenClause), $2.(*expression.WhenClause)) + $$ = append($1.([]*ast.WhenClause), $2.(*ast.WhenClause)) } WhenClause: "WHEN" Expression "THEN" Expression { - $$ = &expression.WhenClause{ - Expr: $2.(expression.Expression), - Result: $4.(expression.Expression), + $$ = &ast.WhenClause{ + Expr: $2.(ast.ExprNode), + Result: $4.(ast.ExprNode), } } @@ -2788,7 +2455,7 @@ CastType: } | "DECIMAL" FloatOpt { - fopt := $2.(*coldef.FloatOpt) + fopt := $2.(*ast.FloatOpt) x := types.NewFieldType(mysql.TypeNewDecimal) x.Flen = fopt.Flen x.Decimal = fopt.Decimal @@ -2816,69 +2483,70 @@ CastType: PrimaryFactor: PrimaryFactor '|' PrimaryFactor %prec '|' { - $$ = expression.NewBinaryOperation(opcode.Or, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Or, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '&' PrimaryFactor %prec '&' { - $$ = expression.NewBinaryOperation(opcode.And, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.And, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor "<<" PrimaryFactor %prec lsh { - $$ = expression.NewBinaryOperation(opcode.LeftShift, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.LeftShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor ">>" PrimaryFactor %prec rsh { - $$ = expression.NewBinaryOperation(opcode.RightShift, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.RightShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '+' PrimaryFactor %prec '+' { - $$ = expression.NewBinaryOperation(opcode.Plus, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Plus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '-' PrimaryFactor %prec '-' { - $$ = expression.NewBinaryOperation(opcode.Minus, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Minus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '*' PrimaryFactor %prec '*' { - $$ = expression.NewBinaryOperation(opcode.Mul, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Mul, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '/' PrimaryFactor %prec '/' { - $$ = expression.NewBinaryOperation(opcode.Div, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Div, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '%' PrimaryFactor %prec '%' { - $$ = expression.NewBinaryOperation(opcode.Mod, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor "DIV" PrimaryFactor %prec div { - $$ = expression.NewBinaryOperation(opcode.IntDiv, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.IntDiv, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor "MOD" PrimaryFactor %prec mod { - $$ = expression.NewBinaryOperation(opcode.Mod, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '^' PrimaryFactor { - $$ = expression.NewBinaryOperation(opcode.Xor, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Xor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryExpression + Priority: { - $$ = stmts.NoPriority + $$ = ast.NoPriority } | "LOW_PRIORITY" { - $$ = stmts.LowPriority + $$ = ast.LowPriority } | "HIGH_PRIORITY" { - $$ = stmts.HighPriority + $$ = ast.HighPriority } | "DELAYED" { - $$ = stmts.DelayedPriority + $$ = ast.DelayedPriority } LowPriorityOptional: @@ -2890,55 +2558,25 @@ LowPriorityOptional: $$ = true } -QualifiedIdent: +TableName: Identifier -| Identifier '.' '*' { - $$ = fmt.Sprintf("%s.*", $1.(string)) + $$ = &ast.TableName{Name:model.NewCIStr($1.(string))} } | Identifier '.' Identifier { - $$ = fmt.Sprintf("%s.%s", $1.(string), $3.(string)) - } -| Identifier '.' Identifier '.' Identifier - { - $$ = fmt.Sprintf("%s.%s.%s", $1.(string), $3.(string), $5.(string)) - } -| Identifier '.' Identifier '.' '*' - { - $$ = fmt.Sprintf("%s.%s.*", $1.(string), $3.(string)) + $$ = &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} } -SimpleQualifiedIdent: - Identifier -| Identifier '.' Identifier +TableNameList: + TableName { - $$ = fmt.Sprintf("%s.%s", $1.(string), $3.(string)) - } -| Identifier '.' Identifier '.' Identifier - { - $$ = fmt.Sprintf("%s.%s.%s", $1.(string), $3.(string), $5.(string)) - } - -TableIdent: - Identifier - { - $$ = table.Ident{Name:model.NewCIStr($1.(string))} - } -| Identifier '.' Identifier - { - $$ = table.Ident{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - } - -TableIdentList: - TableIdent - { - tbl := []table.Ident{$1.(table.Ident)} + tbl := []*ast.TableName{$1.(*ast.TableName)} $$ = tbl } -| TableIdentList ',' TableIdent +| TableNameList ',' TableName { - $$ = append($1.([]table.Ident), $3.(table.Ident)) + $$ = append($1.([]*ast.TableName), $3.(*ast.TableName)) } QuickOptional: @@ -2951,7 +2589,6 @@ QuickOptional: $$ = true } - /***************************Prepared Statement Start****************************** * See: https://dev.mysql.com/doc/refman/5.7/en/prepare.html * Example: @@ -2965,14 +2602,14 @@ PreparedStmt: "PREPARE" Identifier "FROM" PrepareSQL { var sqlText string - var sqlVar *expression.Variable + var sqlVar *ast.VariableExpr switch $4.(type) { case string: sqlText = $4.(string) - case *expression.Variable: - sqlVar = $4.(*expression.Variable) + case *ast.VariableExpr: + sqlVar = $4.(*ast.VariableExpr) } - $$ = &stmts.PreparedStmt{ + $$ = &ast.PrepareStmt{ InPrepare: true, Name: $2.(string), SQLText: sqlText, @@ -2995,24 +2632,24 @@ PrepareSQL: ExecuteStmt: "EXECUTE" Identifier { - $$ = &stmts.ExecuteStmt{Name: $2.(string)} + $$ = &ast.ExecuteStmt{Name: $2.(string)} } | "EXECUTE" Identifier "USING" UserVariableList { - $$ = &stmts.ExecuteStmt{ + $$ = &ast.ExecuteStmt{ Name: $2.(string), - UsingVars: $4.([]expression.Expression), + UsingVars: $4.([]ast.ExprNode), } } UserVariableList: UserVariable { - $$ = []expression.Expression{$1.(expression.Expression)} + $$ = []ast.ExprNode{$1.(ast.ExprNode)} } | UserVariableList ',' UserVariable { - $$ = append($1.([]expression.Expression), $3.(expression.Expression)) + $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) } /* @@ -3022,7 +2659,7 @@ UserVariableList: DeallocateStmt: DeallocateSym "PREPARE" Identifier { - $$ = &stmts.DeallocateStmt{Name: $3.(string)} + $$ = &ast.DeallocateStmt{Name: $3.(string)} } DeallocateSym: @@ -3034,93 +2671,122 @@ DeallocateSym: RollbackStmt: "ROLLBACK" { - $$ = &stmts.RollbackStmt{} + $$ = &ast.RollbackStmt{} } SelectStmt: "SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt { - $$ = &stmts.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), - Fields: $3.([]*field.Field), - Lock: $5.(coldef.LockType), + Fields: $3.(*ast.FieldList), + LockTp: $5.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + src := yylex.(*lexer).src + var lastEnd int + if $4 != nil { + lastEnd = yyS[yypt-1].offset-1 + } else if $5 != ast.SelectLockNone { + lastEnd = yyS[yypt].offset-1 + } else { + lastEnd = len(src) + if src[lastEnd-1] == ';' { + lastEnd-- + } + } + lastField.SetText(src[lastField.Offset:lastEnd]) + } + if $4 != nil { + st.Limit = $4.(*ast.Limit) + } + $$ = st } | "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt { - st := &stmts.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), - Fields: $3.([]*field.Field), - From: nil, - Lock: $7.(coldef.LockType), + Fields: $3.(*ast.FieldList), + LockTp: $7.(ast.SelectLockType), + } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-3].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) } - if $5 != nil { - st.Where = &rsets.WhereRset{Expr: $5.(expression.Expression)} + st.Where = $5.(ast.ExprNode) + } + if $6 != nil { + st.Limit = $6.(*ast.Limit) } - $$ = st } -| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - FromClause WhereClauseOptional SelectStmtGroup HavingClause SelectStmtOrder +| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" + TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional SelectStmtLimit SelectLockOpt { - st := &stmts.SelectStmt{ + st := &ast.SelectStmt{ Distinct: $2.(bool), - Fields: $3.([]*field.Field), - From: $5.(*rsets.JoinRset), - Lock: $11.(coldef.LockType), + Fields: $3.(*ast.FieldList), + From: $5.(*ast.TableRefsClause), + LockTp: $11.(ast.SelectLockType), + } + + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-7].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) } if $6 != nil { - st.Where = &rsets.WhereRset{Expr: $6.(expression.Expression)} + st.Where = $6.(ast.ExprNode) } if $7 != nil { - st.GroupBy = $7.(*rsets.GroupByRset) + st.GroupBy = $7.(*ast.GroupByClause) } if $8 != nil { - st.Having = $8.(*rsets.HavingRset) + st.Having = $8.(*ast.HavingClause) } if $9 != nil { - st.OrderBy = $9.(*rsets.OrderByRset) + st.OrderBy = $9.(*ast.OrderByClause) } if $10 != nil { - ay := $10.([]interface{}) - st.Limit = ay[0].(*rsets.LimitRset) - st.Offset = ay[1].(*rsets.OffsetRset) + st.Limit = $10.(*ast.Limit) } - + $$ = st } FromDual: - "FROM" "DUAL" + "FROM" "DUAL" -FromClause: +TableRefsClause: TableRefs { - $$ = $1 + $$ = &ast.TableRefsClause{TableRefs: $1.(*ast.Join)} } TableRefs: EscapedTableRef { - r := &rsets.JoinRset{Left: $1, Right: nil} - if j, ok := $1.(*rsets.JoinRset); ok { - // if $1 is JoinRset, use it directly - r = j - } - $$ = r + if j, ok := $1.(*ast.Join); ok { + // if $1 is Join, use it directly + $$ = j + } else { + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: nil} + } } | TableRefs ',' EscapedTableRef { /* from a, b is default cross join */ - $$ = &rsets.JoinRset{Left: $1, Right: $3, Type: rsets.CrossJoin} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} } EscapedTableRef: @@ -3148,66 +2814,70 @@ TableRef: } TableFactor: - TableIdent TableIdentOpt + TableName TableAsNameOpt { - $$ = &rsets.TableSource{Source: $1, Name: $2.(string)} + $$ = &ast.TableSource{Source: $1.(*ast.TableName), AsName: $2.(model.CIStr)} } -| '(' SelectStmt ')' TableAsOpt +| '(' SelectStmt ')' TableAsName { - $$ = &rsets.TableSource{Source: $2, Name: $4.(string)} + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1) + $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } -| '(' UnionStmt ')' TableAsOpt +| '(' UnionStmt ')' TableAsName { - $$ = &rsets.TableSource{Source: $2, Name: $4.(string)} + $$ = &ast.TableSource{Source: $2.(*ast.UnionStmt), AsName: $4.(model.CIStr)} } | '(' TableRefs ')' { $$ = $2 } -TableIdentOpt: +TableAsNameOpt: { - $$ = "" + $$ = model.CIStr{} } -| TableAsOpt +| TableAsName { $$ = $1 } -TableAsOpt: +TableAsName: Identifier { - $$ = $1 + $$ = model.NewCIStr($1.(string)) } | "AS" Identifier { - $$ = $2 + $$ = model.NewCIStr($2.(string)) } JoinTable: /* Use %prec to evaluate production TableRef before cross join */ TableRef CrossOpt TableRef %prec tableRefPriority { - $$ = &rsets.JoinRset{Left: $1, Right: $3, Type: rsets.CrossJoin} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} } | TableRef CrossOpt TableRef "ON" Expression { - $$ = &rsets.JoinRset{Left: $1, Right: $3, Type: rsets.CrossJoin, On: $5.(expression.Expression)} + on := &ast.OnCondition{Expr: $5.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on} } | TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression - { - $$ = &rsets.JoinRset{Left: $1, Right: $5, Type: $2.(rsets.JoinType), On: $7.(expression.Expression)} - } + { + on := &ast.OnCondition{Expr: $7.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on} + } /* Support Using */ JoinType: "LEFT" { - $$ = rsets.LeftJoin + $$ = ast.LeftJoin } | "RIGHT" { - $$ = rsets.RightJoin + $$ = ast.RightJoin } OuterOpt: @@ -3225,11 +2895,11 @@ CrossOpt: LimitClause: { - $$ = (*rsets.LimitRset)(nil) + $$ = nil } | "LIMIT" LengthNum { - $$ = &rsets.LimitRset{Count: $2.(uint64)} + $$ = &ast.Limit{Count: $2.(uint64)} } SelectStmtLimit: @@ -3238,21 +2908,15 @@ SelectStmtLimit: } | "LIMIT" LengthNum { - $$ = []interface{}{ - &rsets.LimitRset{Count: $2.(uint64)}, - (*rsets.OffsetRset)(nil)} + $$ = &ast.Limit{Count: $2.(uint64)} } | "LIMIT" LengthNum ',' LengthNum { - $$ = []interface{}{ - &rsets.LimitRset{Count: $4.(uint64)}, - &rsets.OffsetRset{Count: $2.(uint64)}} + $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} } | "LIMIT" LengthNum "OFFSET" LengthNum { - $$ = []interface{}{ - &rsets.LimitRset{Count: $2.(uint64)}, - &rsets.OffsetRset{Count: $4.(uint64)}} + $$ = &ast.Limit{Offset: $4.(uint64), Count: $2.(uint64)} } SelectStmtDistinct: @@ -3289,7 +2953,7 @@ SelectStmtCalcFoundRows: SelectStmtFieldList: FieldList { - $$ = $1 + $$ = &ast.FieldList{Fields: $1.([]*ast.SelectField)} } SelectStmtGroup: @@ -3299,94 +2963,157 @@ SelectStmtGroup: } | GroupByClause -SelectStmtOrder: - /* EMPTY */ - { - $$ = nil - } -| OrderBy - - // See: https://dev.mysql.com/doc/refman/5.7/en/subqueries.html SubSelect: '(' SelectStmt ')' { - s := $2.(*stmts.SelectStmt) + s := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) - $$ = &subquery.SubQuery{Stmt: s} + $$ = &ast.SubqueryExpr{Query: s} } | '(' UnionStmt ')' { - s := $2.(*stmts.UnionStmt) + s := $2.(*ast.UnionStmt) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) - $$ = &subquery.SubQuery{Stmt: s} + $$ = &ast.SubqueryExpr{Query: s} } // See: https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html SelectLockOpt: /* empty */ { - $$ = coldef.SelectLockNone + $$ = ast.SelectLockNone } | "FOR" "UPDATE" { - $$ = coldef.SelectLockForUpdate + $$ = ast.SelectLockForUpdate } | "LOCK" "IN" "SHARE" "MODE" { - $$ = coldef.SelectLockInShareMode + $$ = ast.SelectLockInShareMode } +// See: https://dev.mysql.com/doc/refman/5.7/en/union.html +UnionStmt: + UnionClauseList "UNION" UnionOpt SelectStmt + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union + } +| UnionClauseList "UNION" UnionOpt '(' SelectStmt ')' OrderByOptional SelectStmtLimit + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1) + st := $5.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, st) + if $7 != nil { + union.OrderBy = $7.(*ast.OrderByClause) + } + if $8 != nil { + union.Limit = $8.(*ast.Limit) + } + $$ = union + } + +UnionClauseList: + UnionSelect + { + selects := []*ast.SelectStmt{$1.(*ast.SelectStmt)} + $$ = &ast.UnionStmt{ + Selects: selects, + } + } +| UnionClauseList "UNION" UnionOpt UnionSelect + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union + } + +UnionSelect: + SelectStmt +| '(' SelectStmt ')' + { + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1) + $$ = st + } + +UnionOpt: + { + $$ = true + } +| "ALL" + { + $$ = false + } +| "DISTINCT" + { + $$ = true + } + + /********************Set Statement*******************************/ SetStmt: "SET" VariableAssignmentList { - $$ = &stmts.SetStmt{Variables: $2.([]*stmts.VariableAssignment)} + $$ = &ast.SetStmt{Variables: $2.([]*ast.VariableAssignment)} } -| "SET" "NAMES" CharsetName +| "SET" "NAMES" StringName { - $$ = &stmts.SetCharsetStmt{Charset: $3.(string)} + $$ = &ast.SetCharsetStmt{Charset: $3.(string)} } -| "SET" "NAMES" CharsetName "COLLATE" CollationName +| "SET" "NAMES" StringName "COLLATE" StringName { - $$ = &stmts.SetCharsetStmt{ + $$ = &ast.SetCharsetStmt{ Charset: $3.(string), Collate: $5.(string), } } -| "SET" CharsetKw CharsetName +| "SET" CharsetKw StringName { - $$ = &stmts.SetCharsetStmt{Charset: $3.(string)} + $$ = &ast.SetCharsetStmt{Charset: $3.(string)} } | "SET" "PASSWORD" eq PasswordOpt { - $$ = &stmts.SetPwdStmt{Password: $4.(string)} + $$ = &ast.SetPwdStmt{Password: $4.(string)} } | "SET" "PASSWORD" "FOR" Username eq PasswordOpt { - $$ = &stmts.SetPwdStmt{User: $4.(string), Password: $6.(string)} + $$ = &ast.SetPwdStmt{User: $4.(string), Password: $6.(string)} } VariableAssignment: Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $1.(string), Value: $3.(expression.Expression), IsSystem: true} + $$ = &ast.VariableAssignment{Name: $1.(string), Value: $3.(ast.ExprNode), IsSystem: true} } | "GLOBAL" Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $2.(string), Value: $4.(expression.Expression), IsGlobal: true, IsSystem: true} + $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsGlobal: true, IsSystem: true} } | "SESSION" Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $2.(string), Value: $4.(expression.Expression), IsSystem: true} + $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} } | "LOCAL" Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $2.(string), Value: $4.(expression.Expression), IsSystem: true} + $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} } | "SYS_VAR" eq Expression { @@ -3402,26 +3129,26 @@ VariableAssignment: } else if strings.HasPrefix(v, "@@") { v = strings.TrimPrefix(v, "@@") } - $$ = &stmts.VariableAssignment{Name: v, Value: $3.(expression.Expression), IsGlobal: isGlobal, IsSystem: true} + $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode), IsGlobal: isGlobal, IsSystem: true} } | "USER_VAR" eq Expression { v := $1.(string) v = strings.TrimPrefix(v, "@") - $$ = &stmts.VariableAssignment{Name: v, Value: $3.(expression.Expression)} + $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode)} } VariableAssignmentList: { - $$ = []*stmts.VariableAssignment{} + $$ = []*ast.VariableAssignment{} } | VariableAssignment { - $$ = []*stmts.VariableAssignment{$1.(*stmts.VariableAssignment)} + $$ = []*ast.VariableAssignment{$1.(*ast.VariableAssignment)} } | VariableAssignmentList ',' VariableAssignment { - $$ = append($1.([]*stmts.VariableAssignment), $3.(*stmts.VariableAssignment)) + $$ = append($1.([]*ast.VariableAssignment), $3.(*ast.VariableAssignment)) } Variable: @@ -3442,7 +3169,7 @@ SystemVariable: } else if strings.HasPrefix(v, "@@") { v = strings.TrimPrefix(v, "@@") } - $$ = &expression.Variable{Name: v, IsGlobal: isGlobal, IsSystem: true} + $$ = &ast.VariableExpr{Name: v, IsGlobal: isGlobal, IsSystem: true} } UserVariable: @@ -3450,7 +3177,7 @@ UserVariable: { v := $1.(string) v = strings.TrimPrefix(v, "@") - $$ = &expression.Variable{Name: v, IsGlobal: false, IsSystem: false} + $$ = &ast.VariableExpr{Name: v, IsGlobal: false, IsSystem: false} } Username: @@ -3479,77 +3206,91 @@ AuthString: ShowStmt: "SHOW" "ENGINES" { - $$ = &stmts.ShowStmt{Target: stmt.ShowEngines} + $$ = &ast.ShowStmt{Tp: ast.ShowEngines} } | "SHOW" "DATABASES" { - $$ = &stmts.ShowStmt{Target: stmt.ShowDatabases} + $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} } | "SHOW" "SCHEMAS" { - $$ = &stmts.ShowStmt{Target: stmt.ShowDatabases} + $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} } | "SHOW" "CHARACTER" "SET" { - $$ = &stmts.ShowStmt{Target: stmt.ShowCharset} + $$ = &ast.ShowStmt{Tp: ast.ShowCharset} } | "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt { - stmt := &stmts.ShowStmt{ - Target: stmt.ShowTables, - DBName: $4.(string), - Full: $2.(bool), + stmt := &ast.ShowStmt{ + Tp: ast.ShowTables, + DBName: $4.(string), + Full: $2.(bool), + } + if $5 != nil { + if x, ok := $5.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { + stmt.Where = $5.(ast.ExprNode) + } } - stmt.SetCondition($5) $$ = stmt } -| "SHOW" OptFull "COLUMNS" ShowTableIdentOpt ShowDatabaseNameOpt +| "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt { - $$ = &stmts.ShowStmt{ - Target: stmt.ShowColumns, - TableIdent: $4.(table.Ident), - DBName: $5.(string), - Full: $2.(bool), + $$ = &ast.ShowStmt{ + Tp: ast.ShowColumns, + Table: $4.(*ast.TableName), + DBName: $5.(string), + Full: $2.(bool), } } | "SHOW" "WARNINGS" { - $$ = &stmts.ShowStmt{Target: stmt.ShowWarnings} + $$ = &ast.ShowStmt{Tp: ast.ShowWarnings} } | "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt { - stmt := &stmts.ShowStmt{ - Target: stmt.ShowVariables, + stmt := &ast.ShowStmt{ + Tp: ast.ShowVariables, GlobalScope: $2.(bool), } - stmt.SetCondition($4) + if x, ok := $4.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else if $4 != nil { + stmt.Where = $4.(ast.ExprNode) + } $$ = stmt } | "SHOW" "COLLATION" ShowLikeOrWhereOpt { - stmt := &stmts.ShowStmt{ - Target: stmt.ShowCollation, + stmt := &ast.ShowStmt{ + Tp: ast.ShowCollation, + } + if x, ok := $3.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else if $3 != nil { + stmt.Where = $3.(ast.ExprNode) } - stmt.SetCondition($3) $$ = stmt } -| "SHOW" "CREATE" "TABLE" TableIdent +| "SHOW" "CREATE" "TABLE" TableName { - $$ = &stmts.ShowStmt{ - Target: stmt.ShowCreateTable, - TableIdent: $4.(table.Ident), + $$ = &ast.ShowStmt{ + Tp: ast.ShowCreateTable, + Table: $4.(*ast.TableName), } } | "SHOW" "GRANTS" { // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html - $$ = &stmts.ShowStmt{Target: stmt.ShowGrants} + $$ = &ast.ShowStmt{Tp: ast.ShowGrants} } | "SHOW" "GRANTS" "FOR" Username { // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html - $$ = &stmts.ShowStmt{ - Target: stmt.ShowGrants, + $$ = &ast.ShowStmt{ + Tp: ast.ShowGrants, User: $4.(string), } } @@ -3560,11 +3301,11 @@ ShowLikeOrWhereOpt: } | "LIKE" PrimaryExpression { - $$ = &expression.PatternLike{Pattern: $2.(expression.Expression)} + $$ = &ast.PatternLikeExpr{Pattern: $2.(ast.ExprNode)} } | "WHERE" Expression { - $$ = expression.Expr($2) + $$ = $2.(ast.ExprNode) } GlobalScope: @@ -3602,14 +3343,14 @@ ShowDatabaseNameOpt: $$ = $2.(string) } -ShowTableIdentOpt: - "FROM" TableIdent +ShowTableAliasOpt: + "FROM" TableName { - $$ = $2.(table.Ident) + $$ = $2.(*ast.TableName) } -| "IN" TableIdent +| "IN" TableName { - $$ = $2.(table.Ident) + $$ = $2.(*ast.TableName) } Statement: @@ -3635,17 +3376,17 @@ Statement: | RollbackStmt | ReplaceIntoStmt | SelectStmt +| UnionStmt | SetStmt | ShowStmt | TruncateTableStmt -| UnionStmt | UpdateStmt | UseStmt | SubSelect { // `(select 1)`; is a valid select statement // TODO: This is used to fix issue #320. There may be a better solution. - $$ = $1.(*subquery.SubQuery).Stmt + $$ = $1.(*ast.SubqueryExpr).Query } ExplainableStmt: @@ -3659,46 +3400,38 @@ StatementList: Statement { if $1 != nil { - n, ok := $1.(ast.Node) - if ok { - n.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = []interface{}{n} - } else { - s := $1.(stmt.Statement) - s.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = []interface{}{s} - } + s := $1.(ast.StmtNode) + s.SetText(yylex.(*lexer).stmtText()) + yylex.(*lexer).list = append(yylex.(*lexer).list, s) } } | StatementList ';' Statement { if $3 != nil { - s := $3.(stmt.Statement) + s := $3.(ast.StmtNode) s.SetText(yylex.(*lexer).stmtText()) yylex.(*lexer).list = append(yylex.(*lexer).list, s) } } -TableConstraint: - "CONSTRAINT" name ConstraintElem +Constraint: + ConstraintKeywordOpt ConstraintElem { - cst := $3.(*coldef.TableConstraint) - cst.ConstrName = $2.(string) + cst := $2.(*ast.Constraint) + if $1 != nil { + cst.Name = $1.(string) + } $$ = cst } -| ConstraintElem - { - $$ = $1 - } TableElement: ColumnDef { - $$ = $1.(*coldef.ColumnDef) + $$ = $1.(*ast.ColumnDef) } -| TableConstraint +| Constraint { - $$ = $1.(*coldef.TableConstraint) + $$ = $1.(*ast.Constraint) } | "CHECK" '(' Expression ')' { @@ -3724,96 +3457,90 @@ TableElementList: } } -TableElementListOpt: - { - $$ = []interface{}{} - } -| TableElementList - -TableOpt: +TableOption: "ENGINE" Identifier { - $$ = &coldef.TableOpt{Tp: coldef.TblOptEngine, StrValue: $2.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $2.(string)} } | "ENGINE" eq Identifier { - $$ = &coldef.TableOpt{Tp: coldef.TblOptEngine, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $3.(string)} } -| DefaultKwdOpt CharsetKw EqOpt CharsetName +| DefaultKwdOpt CharsetKw EqOpt StringName { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCharset, StrValue: $4.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionCharset, StrValue: $4.(string)} } -| DefaultKwdOpt "COLLATE" EqOpt CollationName +| DefaultKwdOpt "COLLATE" EqOpt StringName { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCollate, StrValue: $4.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: $4.(string)} } | "AUTO_INCREMENT" eq LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptAutoIncrement, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionAutoIncrement, UintValue: $3.(uint64)} } | "COMMENT" EqOpt stringLit { - $$ = &coldef.TableOpt{Tp: coldef.TblOptComment, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionComment, StrValue: $3.(string)} } | "AVG_ROW_LENGTH" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptAvgRowLength, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionAvgRowLength, UintValue: $3.(uint64)} } | "CONNECTION" EqOpt stringLit { - $$ = &coldef.TableOpt{Tp: coldef.TblOptConnection, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionConnection, StrValue: $3.(string)} } | "CHECKSUM" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCheckSum, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionCheckSum, UintValue: $3.(uint64)} } | "PASSWORD" EqOpt stringLit { - $$ = &coldef.TableOpt{Tp: coldef.TblOptPassword, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionPassword, StrValue: $3.(string)} } | "COMPRESSION" EqOpt Identifier { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCompression, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionCompression, StrValue: $3.(string)} } | "KEY_BLOCK_SIZE" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptKeyBlockSize, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionKeyBlockSize, UintValue: $3.(uint64)} } | "MAX_ROWS" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptMaxRows, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionMaxRows, UintValue: $3.(uint64)} } | "MIN_ROWS" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptMinRows, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionMinRows, UintValue: $3.(uint64)} } -TableOptListOpt: +TableOptionListOpt: { - $$ = []*coldef.TableOpt{} + $$ = []*ast.TableOption{} } -| TableOptList %prec lowerThanComma +| TableOptionList %prec lowerThanComma -TableOptList: - TableOpt +TableOptionList: + TableOption { - $$ = []*coldef.TableOpt{$1.(*coldef.TableOpt)} + $$ = []*ast.TableOption{$1.(*ast.TableOption)} } -| TableOptList TableOpt +| TableOptionList TableOption { - $$ = append($1.([]*coldef.TableOpt), $2.(*coldef.TableOpt)) + $$ = append($1.([]*ast.TableOption), $2.(*ast.TableOption)) } -| TableOptList ',' TableOpt +| TableOptionList ',' TableOption { - $$ = append($1.([]*coldef.TableOpt), $3.(*coldef.TableOpt)) + $$ = append($1.([]*ast.TableOption), $3.(*ast.TableOption)) } TruncateTableStmt: - "TRUNCATE" "TABLE" TableIdent + "TRUNCATE" "TABLE" TableName { - $$ = &stmts.TruncateTableStmt{TableIdent: $3.(table.Ident)} + $$ = &ast.TruncateTableStmt{Table: $3.(*ast.TableName)} } /*************************************Type Begin***************************************/ @@ -3879,7 +3606,7 @@ NumericType: } | FixedPointType FloatOpt FieldOpts { - fopt := $2.(*coldef.FloatOpt) + fopt := $2.(*ast.FloatOpt) x := types.NewFieldType($1.(byte)) x.Flen = fopt.Flen x.Decimal = fopt.Decimal @@ -3895,7 +3622,7 @@ NumericType: } | FloatingPointType FloatOpt FieldOpts { - fopt := $2.(*coldef.FloatOpt) + fopt := $2.(*ast.FloatOpt) x := types.NewFieldType($1.(byte)) x.Flen = fopt.Flen if x.Tp == mysql.TypeFloat { @@ -3927,7 +3654,6 @@ NumericType: x.Flen = 1 } else if x.Flen > 64 { yylex.(*lexer).errf("invalid field length %d for bit type, must in [1, 64]", x.Flen) - return 1 } $$ = x } @@ -4011,8 +3737,6 @@ StringType: if $4.(bool) { x.Flag |= mysql.BinaryFlag } - x.Charset = $5.(string) - x.Collate = $6.(string) $$ = x } | NationalOpt "CHAR" OptBinary OptCharset OptCollate @@ -4021,8 +3745,6 @@ StringType: if $3.(bool) { x.Flag |= mysql.BinaryFlag } - x.Charset = $4.(string) - x.Collate = $5.(string) $$ = x } | NationalOpt "VARCHAR" FieldLen OptBinary OptCharset OptCollate @@ -4216,21 +3938,21 @@ FieldOpts: FloatOpt: { - $$ = &coldef.FloatOpt{Flen: types.UnspecifiedLength, Decimal: types.UnspecifiedLength} + $$ = &ast.FloatOpt{Flen: types.UnspecifiedLength, Decimal: types.UnspecifiedLength} } | FieldLen { - $$ = &coldef.FloatOpt{Flen: $1.(int), Decimal: types.UnspecifiedLength} + $$ = &ast.FloatOpt{Flen: $1.(int), Decimal: types.UnspecifiedLength} } | Precision { - $$ = $1.(*coldef.FloatOpt) + $$ = $1.(*ast.FloatOpt) } Precision: '(' LengthNum ',' LengthNum ')' { - $$ = &coldef.FloatOpt{Flen: int($2.(uint64)), Decimal: int($4.(uint64))} + $$ = &ast.FloatOpt{Flen: int($2.(uint64)), Decimal: int($4.(uint64))} } OptBinary: @@ -4246,7 +3968,7 @@ OptCharset: { $$ = "" } -| CharsetKw CharsetName +| CharsetKw StringName { $$ = $2.(string) } @@ -4259,7 +3981,7 @@ OptCollate: { $$ = "" } -| "COLLATE" CollationName +| "COLLATE" StringName { $$ = $2.(string) } @@ -4274,71 +3996,15 @@ StringList: $$ = append($1.([]string), $3.(string)) } -/************************************************************************************ - * Union statement - * See: https://dev.mysql.com/doc/refman/5.7/en/union.html - ***********************************************************************************/ -UnionStmt: - UnionSelect "UNION" UnionOpt SelectStmt +StringName: + stringLit { - ds := []bool {$3.(bool)} - ss := []*stmts.SelectStmt{$1.(*stmts.SelectStmt), $4.(*stmts.SelectStmt)} - $$ = &stmts.UnionStmt{ - Distincts: ds, - Selects: ss, - } + $$ = $1.(string) } -| UnionSelect "UNION" UnionOpt '(' SelectStmt ')' SelectStmtOrder SelectStmtLimit +| Identifier { - ds := []bool {$3.(bool)} - ss := []*stmts.SelectStmt{$1.(*stmts.SelectStmt), $5.(*stmts.SelectStmt)} - st := &stmts.UnionStmt{ - Distincts: ds, - Selects: ss, - } - if $7 != nil { - st.OrderBy = $7.(*rsets.OrderByRset) - } - - if $8 != nil { - ay := $8.([]interface{}) - st.Limit = ay[0].(*rsets.LimitRset) - st.Offset = ay[1].(*rsets.OffsetRset) - } - $$ = st + $$ = $1.(string) } -| UnionSelect "UNION" UnionOpt UnionStmt - { - s := $4.(*stmts.UnionStmt) - s.Distincts = append([]bool {$3.(bool)}, s.Distincts...) - s.Selects = append([]*stmts.SelectStmt{$1.(*stmts.SelectStmt)}, s.Selects...) - $$ = s - } - -UnionSelect: - SelectStmt - { - $$ = $1 - } -| '(' SelectStmt ')' - { - $$ = $2 - } - - -UnionOpt: - { - $$ = true - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } - /*********************************************************************************** * Update Statement @@ -4347,21 +4013,25 @@ UnionOpt: UpdateStmt: "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { - // Single-table syntax - r := &rsets.JoinRset{Left: $4, Right: nil} - st := &stmts.UpdateStmt{ + var refs *ast.Join + if x, ok := $4.(*ast.Join); ok { + refs = x + } else { + refs = &ast.Join{Left: $4.(ast.ResultSetNode)} + } + st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: r, - List: $6.([]*expression.Assignment), + TableRefs: &ast.TableRefsClause{TableRefs: refs}, + List: $6.([]*ast.Assignment), } if $7 != nil { - st.Where = $7.(expression.Expression) + st.Where = $7.(ast.ExprNode) } if $8 != nil { - st.Order = $8.(*rsets.OrderByRset) + st.Order = $8.(*ast.OrderByClause) } if $9 != nil { - st.Limit = $9.(*rsets.LimitRset) + st.Limit = $9.(*ast.Limit) } $$ = st if yylex.(*lexer).root { @@ -4370,15 +4040,13 @@ UpdateStmt: } | "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional { - // Multiple-table syntax - st := &stmts.UpdateStmt{ + st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: $4.(*rsets.JoinRset), - List: $6.([]*expression.Assignment), - MultipleTable: true, + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, + List: $6.([]*ast.Assignment), } if $7 != nil { - st.Where = $7.(expression.Expression) + st.Where = $7.(ast.ExprNode) } $$ = st if yylex.(*lexer).root { @@ -4389,7 +4057,7 @@ UpdateStmt: UseStmt: "USE" DBName { - $$ = &stmts.UseStmt{DBName: $2.(string)} + $$ = &ast.UseStmt{DBName: $2.(string)} if yylex.(*lexer).root { break } @@ -4398,7 +4066,7 @@ UseStmt: WhereClause: "WHERE" Expression { - $$ = expression.Expr($2) + $$ = $2.(ast.ExprNode) } WhereClauseOptional: @@ -4422,36 +4090,35 @@ CommaOpt: * https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html ************************************************************************************/ CreateUserStmt: - "CREATE" "USER" IfNotExists UserSpecificationList + "CREATE" "USER" IfNotExists UserSpecList { // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html - $$ = &stmts.CreateUserStmt{ + $$ = &ast.CreateUserStmt{ IfNotExists: $3.(bool), - Specs: $4.([]*coldef.UserSpecification), + Specs: $4.([]*ast.UserSpec), } } -UserSpecification: +UserSpec: Username AuthOption { - x := &coldef.UserSpecification{ + userSpec := &ast.UserSpec{ User: $1.(string), } if $2 != nil { - x.AuthOpt = $2.(*coldef.AuthOption) + userSpec.AuthOpt = $2.(*ast.AuthOption) } - $$ = x - + $$ = userSpec } -UserSpecificationList: - UserSpecification +UserSpecList: + UserSpec { - $$ = []*coldef.UserSpecification{$1.(*coldef.UserSpecification)} + $$ = []*ast.UserSpec{$1.(*ast.UserSpec)} } -| UserSpecificationList ',' UserSpecification +| UserSpecList ',' UserSpec { - $$ = append($1.([]*coldef.UserSpecification), $3.(*coldef.UserSpecification)) + $$ = append($1.([]*ast.UserSpec), $3.(*ast.UserSpec)) } AuthOption: @@ -4460,14 +4127,14 @@ AuthOption: } | "IDENTIFIED" "BY" AuthString { - $$ = &coldef.AuthOption { + $$ = &ast.AuthOption { AuthString: $3.(string), ByAuthString: true, - } + } } | "IDENTIFIED" "BY" "PASSWORD" HashString { - $$ = &coldef.AuthOption { + $$ = &ast.AuthOption{ HashString: $4.(string), } } @@ -4480,42 +4147,39 @@ HashString: * See: https://dev.mysql.com/doc/refman/5.7/en/grant.html *************************************************************************************/ GrantStmt: - "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecificationList + "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecList { - $$ = &stmts.GrantStmt{ - Privs: $2.([]*coldef.PrivElem), - ObjectType: $4.(int), - Level: $5.(*coldef.GrantLevel), - Users: $7.([]*coldef.UserSpecification), - } + $$ = &ast.GrantStmt{ + Privs: $2.([]*ast.PrivElem), + ObjectType: $4.(ast.ObjectTypeType), + Level: $5.(*ast.GrantLevel), + Users: $7.([]*ast.UserSpec), + } } PrivElem: - PrivType ColumnListOpt + PrivType { - $$ = &coldef.PrivElem{ + $$ = &ast.PrivElem{ Priv: $1.(mysql.PrivilegeType), - Cols: $2.([]string), } } - -ColumnListOpt: +| PrivType '(' ColumnNameList ')' { - $$ = []string{} - } -| '(' ColumnNameList ')' - { - $$ = $2 + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + Cols: $3.([]*ast.ColumnName), + } } PrivElemList: PrivElem { - $$ = []*coldef.PrivElem{$1.(*coldef.PrivElem)} + $$ = []*ast.PrivElem{$1.(*ast.PrivElem)} } | PrivElemList ',' PrivElem { - $$ = append($1.([]*coldef.PrivElem), $3.(*coldef.PrivElem)) + $$ = append($1.([]*ast.PrivElem), $3.(*ast.PrivElem)) } PrivType: @@ -4571,50 +4235,49 @@ PrivType: { $$ = mysql.GrantPriv } - + ObjectType: { - $$ = coldef.ObjectTypeNone + $$ = ast.ObjectTypeNone } | "TABLE" { - $$ = coldef.ObjectTypeTable + $$ = ast.ObjectTypeTable } PrivLevel: '*' { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelDB, - } + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + } } | '*' '.' '*' { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelGlobal, - } + $$ = &ast.GrantLevel { + Level: ast.GrantLevelGlobal, + } } | Identifier '.' '*' { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelDB, + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, DBName: $1.(string), - } + } } | Identifier '.' Identifier { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelTable, + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, DBName: $1.(string), TableName: $3.(string), - } + } } | Identifier { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelTable, + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, TableName: $1.(string), - } + } } %% - diff --git a/parser/parser_test.go b/parser/parser_test.go index 8f5c0d10ee..3e8d7bb278 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1,16 +1,3 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - package parser import ( @@ -18,9 +5,7 @@ import ( "testing" . "github.com/pingcap/check" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/subquery" - "github.com/pingcap/tidb/stmt/stmts" + "github.com/pingcap/tidb/ast" ) func TestT(t *testing.T) { @@ -29,26 +14,11 @@ func TestT(t *testing.T) { var _ = Suite(&testParserSuite{}) +var _ = Suite(&testParserSuite{}) + type testParserSuite struct { } -func (s *testParserSuite) TestOriginText(c *C) { - src := `SELECT stuff.id - FROM stuff - WHERE stuff.value >= ALL (SELECT stuff.value - FROM stuff)` - - l := NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - node := l.Stmts()[0].(*stmts.SelectStmt) - sq := node.Where.Expr.(*expression.CompareSubQuery).R - c.Assert(sq, NotNil) - subsel := sq.(*subquery.SubQuery) - c.Assert(subsel.Stmt.OriginText(), Equals, - `SELECT stuff.value - FROM stuff`) -} - func (s *testParserSuite) TestSimple(c *C) { // Testcase for unreserved keywords unreservedKws := []string{ @@ -58,7 +28,7 @@ func (s *testParserSuite) TestSimple(c *C) { "start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown", "value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified", "collation", "comment", "avg_row_length", "checksum", "compression", "connection", "key_block_size", - "max_rows", "min_rows", "national", "row", "quarter", "escape", "grants", + "max_rows", "min_rows", "national", "row", "quarter", "escape", } for _, kw := range unreservedKws { src := fmt.Sprintf("SELECT %s FROM tbl;", kw) @@ -70,15 +40,12 @@ func (s *testParserSuite) TestSimple(c *C) { // Testcase for prepared statement src := "SELECT id+?, id+? from t;" l := NewLexer(src) - l.SetPrepare() c.Assert(yyParse(l), Equals, 0) - c.Assert(len(l.ParamList), Equals, 2) c.Assert(len(l.Stmts()), Equals, 1) // Testcase for -- Comment and unary -- operator src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;" l = NewLexer(src) - l.SetPrepare() c.Assert(yyParse(l), Equals, 0) c.Assert(len(l.Stmts()), Equals, 2) @@ -87,11 +54,12 @@ func (s *testParserSuite) TestSimple(c *C) { l = NewLexer(src) c.Assert(yyParse(l), Equals, 0) st := l.Stmts()[0] - ss, ok := st.(*stmts.SelectStmt) + ss, ok := st.(*ast.SelectStmt) c.Assert(ok, IsTrue) - cv, ok := ss.Fields[0].Expr.(*expression.FunctionCast) + c.Assert(len(ss.Fields.Fields), Equals, 1) + cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr) c.Assert(ok, IsTrue) - c.Assert(cv.FunctionType, Equals, expression.ConvertFunction) + c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction) // For query start with comment srcs := []string{ @@ -105,7 +73,7 @@ func (s *testParserSuite) TestSimple(c *C) { l = NewLexer(src) c.Assert(yyParse(l), Equals, 0) st = l.Stmts()[0] - ss, ok = st.(*stmts.SelectStmt) + ss, ok = st.(*ast.SelectStmt) c.Assert(ok, IsTrue) } } @@ -172,9 +140,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) { {"REPLACE INTO foo () VALUES ()", true}, {"REPLACE INTO foo VALUE ()", true}, // 40 - {`SELECT stuff.id - FROM stuff - WHERE stuff.value >= ALL (SELECT stuff.value + {`SELECT stuff.id + FROM stuff + WHERE stuff.value >= ALL (SELECT stuff.value FROM stuff)`, true}, {"BEGIN", true}, {"START TRANSACTION", true}, diff --git a/parser/scanner.l b/parser/scanner.l index c1441b4650..9ee7b1b8a9 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -22,25 +22,25 @@ import ( "fmt" "math" "strconv" - "unicode" "strings" - - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/util/stringutil" - "github.com/pingcap/tidb/util/charset" + "unicode" + + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/charset" + "github.com/pingcap/tidb/util/stringutil" ) type lexer struct { c int col int errs []error - expr expression.Expression + expr ast.ExprNode i int inj int lcol int line int - list []interface{} + list []ast.StmtNode ncol int nline int sc int @@ -48,8 +48,7 @@ type lexer struct { val []byte ungetBuf []byte root bool - prepare bool - ParamList []*expression.ParamMarker + prepare bool stmtStartPos int stringLit []byte @@ -78,11 +77,11 @@ func (l *lexer) Errors() []error { return l.errs } -func (l *lexer) Stmts() []interface{}{ +func (l *lexer) Stmts() []ast.StmtNode { return l.list } -func (l *lexer) Expr() expression.Expression { +func (l *lexer) Expr() ast.ExprNode { return l.expr } @@ -90,16 +89,16 @@ func (l *lexer) Inj() int { return l.inj } +func (l *lexer) SetInj(inj int) { + l.inj = inj +} + func (l *lexer) SetPrepare() { - l.prepare = true + l.prepare = true } func (l *lexer) IsPrepare() bool { - return l.prepare -} - -func (l *lexer) SetInj(inj int) { - l.inj = inj + return l.prepare } func (l *lexer) Root() bool { @@ -116,7 +115,17 @@ func (l *lexer) SetCharsetInfo(charset, collation string) { } func (l *lexer) GetCharsetInfo() (string, string) { - return l.charset, l.collation + return l.charset, l.collation +} + +// The select statement is not at the end of the whole statement, if the last +// field text was set from its offset to the end of the src string, update +// the last field text. +func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 { + lastField.SetText(l.src[lastField.Offset:lastEnd]) + } } func (l *lexer) unget(b byte) { @@ -822,9 +831,9 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {repeat} lval.item = string(l.val) return repeat {regexp} return regexp -{references} return references {replace} lval.item = string(l.val) return replace +{references} return references {rlike} return rlike {sys_var} lval.item = string(l.val) @@ -1047,7 +1056,7 @@ func (l *lexer) str(lval *yySymType, pref string) int { pref = "\"" } v := stringutil.RemoveUselessBackslash(pref+s) - v, err := strconv.Unquote(v) + v, err := strconv.Unquote(v) if err != nil { v = strings.TrimSuffix(s, pref) } diff --git a/parser/scanner_test.go b/parser/scanner_test.go deleted file mode 100644 index 0d920cbd5a..0000000000 --- a/parser/scanner_test.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package parser - -import ( - "fmt" - "unicode" - - . "github.com/pingcap/check" -) - -func tok2name(i int) string { - if i == unicode.ReplacementChar { - return "" - } - - if i < 128 { - return fmt.Sprintf("tok-'%c'", i) - } - - return fmt.Sprintf("tok-%d", i) -} - -func (s *testParserSuite) TestScaner0(c *C) { - table := []struct { - src string - tok, line, col, nline, ncol int - val string - }{ - {"a", identifier, 1, 1, 1, 2, "a"}, - {" a", identifier, 1, 2, 1, 3, "a"}, - {"a ", identifier, 1, 1, 1, 2, "a"}, - {" a ", identifier, 1, 2, 1, 3, "a"}, - {"\na", identifier, 2, 1, 2, 2, "a"}, - - {"a\n", identifier, 1, 1, 1, 2, "a"}, - {"\na\n", identifier, 2, 1, 2, 2, "a"}, - {"\n a", identifier, 2, 2, 2, 3, "a"}, - {"a \n", identifier, 1, 1, 1, 2, "a"}, - {"\n a \n", identifier, 2, 2, 2, 3, "a"}, - - {"ab", identifier, 1, 1, 1, 3, "ab"}, - {" ab", identifier, 1, 2, 1, 4, "ab"}, - {"ab ", identifier, 1, 1, 1, 3, "ab"}, - {" ab ", identifier, 1, 2, 1, 4, "ab"}, - {"\nab", identifier, 2, 1, 2, 3, "ab"}, - - {"ab\n", identifier, 1, 1, 1, 3, "ab"}, - {"\nab\n", identifier, 2, 1, 2, 3, "ab"}, - {"\n ab", identifier, 2, 2, 2, 4, "ab"}, - {"ab \n", identifier, 1, 1, 1, 3, "ab"}, - {"\n ab \n", identifier, 2, 2, 2, 4, "ab"}, - - {"c", identifier, 1, 1, 1, 2, "c"}, - {"cR", identifier, 1, 1, 1, 3, "cR"}, - {"cRe", identifier, 1, 1, 1, 4, "cRe"}, - {"cReA", identifier, 1, 1, 1, 5, "cReA"}, - {"cReAt", identifier, 1, 1, 1, 6, "cReAt"}, - - {"cReATe", create, 1, 1, 1, 7, "cReATe"}, - {"cReATeD", identifier, 1, 1, 1, 8, "cReATeD"}, - {"2", intLit, 1, 1, 1, 2, "2"}, - {"2.", floatLit, 1, 1, 1, 3, "2."}, - {"2.3", floatLit, 1, 1, 1, 4, "2.3"}, - } - - lval := &yySymType{} - for _, t := range table { - l := NewLexer(t.src) - tok := l.Lex(lval) - nline, ncol := l.npos() - val := string(l.val) - - c.Assert(tok, Equals, t.tok) - c.Assert(l.line, Equals, t.line) - c.Assert(l.col, Equals, t.col) - c.Assert(nline, Equals, t.nline) - c.Assert(ncol, Equals, t.ncol) - c.Assert(val, Equals, t.val) - } -} diff --git a/tidb.go b/tidb.go index 27fcf5167a..30a95b608f 100644 --- a/tidb.go +++ b/tidb.go @@ -26,13 +26,13 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/optimizer" + "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/variable" From 88711021e8be1b654476f178bd12aa55ae850d5e Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 5 Nov 2015 16:13:29 +0800 Subject: [PATCH 27/27] parser: address comment. --- make.cmd | 11 ----------- parser/parser_test.go | 17 ++++++++++++++--- parser/scanner.l | 2 +- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/make.cmd b/make.cmd index 523ac82f76..83a2dbda67 100644 --- a/make.cmd +++ b/make.cmd @@ -18,17 +18,6 @@ DEL /F /A /Q y.output golex -o parser/scanner.go parser/scanner.l -@echo [Ast-Parser] -go get github.com/qiuyesuifeng/goyacc -go get github.com/qiuyesuifeng/golex -type nul >>temp.XXXXXX -goyacc -o nul -xegen "temp.XXXXXX" parser/parser.y -goyacc -o ast/parser/parser.go -xe "temp.XXXXXX" ast/parser/parser.y -DEL /F /A /Q temp.XXXXXX -DEL /F /A /Q y.output - -golex -o ast/parser/scanner.go ast/parser/scanner.l - @echo [Build] godep go build -ldflags '%LDFLAGS%' diff --git a/parser/parser_test.go b/parser/parser_test.go index 3e8d7bb278..50d6e5cf2d 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -1,3 +1,16 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + package parser import ( @@ -14,8 +27,6 @@ func TestT(t *testing.T) { var _ = Suite(&testParserSuite{}) -var _ = Suite(&testParserSuite{}) - type testParserSuite struct { } @@ -28,7 +39,7 @@ func (s *testParserSuite) TestSimple(c *C) { "start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown", "value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified", "collation", "comment", "avg_row_length", "checksum", "compression", "connection", "key_block_size", - "max_rows", "min_rows", "national", "row", "quarter", "escape", + "max_rows", "min_rows", "national", "row", "quarter", "escape", "grants", } for _, kw := range unreservedKws { src := fmt.Sprintf("SELECT %s FROM tbl;", kw) diff --git a/parser/scanner.l b/parser/scanner.l index 9ee7b1b8a9..191eaef35e 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -1056,7 +1056,7 @@ func (l *lexer) str(lval *yySymType, pref string) int { pref = "\"" } v := stringutil.RemoveUselessBackslash(pref+s) - v, err := strconv.Unquote(v) + v, err := strconv.Unquote(v) if err != nil { v = strings.TrimSuffix(s, pref) }