diff --git a/Makefile b/Makefile index 871e02ac7b..3468b1758c 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ TARGET = "" .PHONY: godep deps all build install parser clean todo test gotest interpreter server -all: godep parser ast-parser build test check +all: godep parser build test check godep: go get github.com/tools/godep @@ -31,7 +31,7 @@ parser: goyacc -o /dev/null -xegen $$a parser/parser.y; \ goyacc -o parser/parser.go -xe $$a parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \ rm -f $$a; \ - rm -f y.output + rm -f y.output @if [ $(ARCH) = $(LINUX) ]; \ then \ @@ -44,24 +44,6 @@ parser: golex -o parser/scanner.go parser/scanner.l -ast-parser: - a=`mktemp temp.XXXXXX`; \ - goyacc -o /dev/null -xegen $$a ast/parser/parser.y; \ - goyacc -o ast/parser/parser.go -xe $$a ast/parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \ - rm -f $$a; \ - rm -f y.output - - @if [ $(ARCH) = $(LINUX) ]; \ - then \ - sed -i -e 's|//line.*||' -e 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \ - elif [ $(ARCH) = $(MAC) ]; \ - then \ - /usr/bin/sed -i "" 's|//line.*||' ast/parser/parser.go; \ - /usr/bin/sed -i "" 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \ - fi - - golex -o ast/parser/scanner.go ast/parser/scanner.l - check: go get github.com/golang/lint/golint diff --git a/ast/ast.go b/ast/ast.go index 0e9ab19893..c38c8bc32d 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -26,6 +26,11 @@ type Node interface { // Accept accepts Visitor to visit itself. // The returned node should replace original node. // ok returns false to stop visiting. + // + // Implementation of this method should first call visitor.Enter, + // assign the returned node to its method receiver, if skipChildren returns true, + // children should be skipped. Otherwise, call its children in particular order that + // later elements depends on former elements. Finally, return visitor.Leave. Accept(v Visitor) (node Node, ok bool) // Text returns the original text of the element. Text() string @@ -44,6 +49,10 @@ type ExprNode interface { SetType(tp *types.FieldType) // GetType gets the evaluation type of the expression. GetType() *types.FieldType + // SetValue sets value to the expression. + SetValue(val interface{}) + // GetValue gets value of the expression. + GetValue() interface{} } // FuncNode represents function call expression node. @@ -93,11 +102,12 @@ type ResultSetNode interface { // Visitor visits a Node. type Visitor interface { // VisitEnter is called before children nodes is visited. + // The returned node must be the same type as the input node n. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. - // ok returns false to stop visiting. - Enter(n Node) (skipChildren bool, ok bool) + Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. + // The returned node must be the same type as the input node n. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) } diff --git a/ast/base.go b/ast/base.go index 50e9ef030c..877fa80b51 100644 --- a/ast/base.go +++ b/ast/base.go @@ -62,7 +62,7 @@ func (dn *dmlNode) dmlStatement() {} // Expression implementations should embed it in. type exprNode struct { node - tp *types.FieldType + types.DataItem } // IsStatic implements Expression interface. @@ -72,12 +72,22 @@ func (en *exprNode) IsStatic() bool { // SetType implements Expression interface. func (en *exprNode) SetType(tp *types.FieldType) { - en.tp = tp + en.Type = tp } // GetType implements Expression interface. func (en *exprNode) GetType() *types.FieldType { - return en.tp + return en.Type +} + +// SetValue implements Expression interface. +func (en *exprNode) SetValue(val interface{}) { + en.Data = val +} + +// GetValue implements Expression interface. +func (en *exprNode) GetValue() interface{} { + return en.Data } type funcNode struct { diff --git a/ast/cloner.go b/ast/cloner.go new file mode 100644 index 0000000000..d3d8d282e5 --- /dev/null +++ b/ast/cloner.go @@ -0,0 +1,173 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import "fmt" + +// Cloner is an ast visitor that clones a node. +type Cloner struct { +} + +// Enter implements Visitor Enter interface. +func (c *Cloner) Enter(node Node) (Node, bool) { + return copyStruct(node), false +} + +// Leave implements Visitor Leave interface. +func (c *Cloner) Leave(in Node) (out Node, ok bool) { + return in, true +} + +// copyStruct copies a node's struct value, if the struct has slice member, +// make a new slice and copy old slice value to new slice. +func copyStruct(in Node) (out Node) { + switch v := in.(type) { + case *ValueExpr: + nv := *v + out = &nv + case *BetweenExpr: + nv := *v + out = &nv + case *BinaryOperationExpr: + nv := *v + out = &nv + case *WhenClause: + nv := *v + out = &nv + case *CaseExpr: + nv := *v + nv.WhenClauses = make([]*WhenClause, len(v.WhenClauses)) + copy(nv.WhenClauses, v.WhenClauses) + out = &nv + case *SubqueryExpr: + nv := *v + out = &nv + case *CompareSubqueryExpr: + nv := *v + out = &nv + case *ColumnName: + nv := *v + out = &nv + case *ColumnNameExpr: + nv := *v + out = &nv + case *DefaultExpr: + nv := *v + out = &nv + case *IdentifierExpr: + nv := *v + out = &nv + case *ExistsSubqueryExpr: + nv := *v + out = &nv + case *PatternInExpr: + nv := *v + nv.List = make([]ExprNode, len(v.List)) + copy(nv.List, v.List) + out = &nv + case *IsNullExpr: + nv := *v + out = &nv + case *IsTruthExpr: + nv := *v + out = &nv + case *PatternLikeExpr: + nv := *v + out = &nv + case *ParamMarkerExpr: + nv := *v + out = &nv + case *ParenthesesExpr: + nv := *v + out = &nv + case *PositionExpr: + nv := *v + out = &nv + case *PatternRegexpExpr: + nv := *v + out = &nv + case *RowExpr: + nv := *v + nv.Values = make([]ExprNode, len(v.Values)) + copy(nv.Values, v.Values) + out = &nv + case *UnaryOperationExpr: + nv := *v + out = &nv + case *ValuesExpr: + nv := *v + out = &nv + case *VariableExpr: + nv := *v + out = &nv + case *Join: + nv := *v + out = &nv + case *TableName: + nv := *v + out = &nv + case *TableSource: + nv := *v + out = &nv + case *OnCondition: + nv := *v + out = &nv + case *WildCardField: + nv := *v + out = &nv + case *SelectField: + nv := *v + out = &nv + case *FieldList: + nv := *v + nv.Fields = make([]*SelectField, len(v.Fields)) + copy(nv.Fields, v.Fields) + out = &nv + case *TableRefsClause: + nv := *v + out = &nv + case *ByItem: + nv := *v + out = &nv + case *GroupByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *HavingClause: + nv := *v + out = &nv + case *OrderByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *SelectStmt: + nv := *v + out = &nv + case *UnionClause: + nv := *v + out = &nv + case *UnionStmt: + nv := *v + nv.Selects = make([]*SelectStmt, len(v.Selects)) + copy(nv.Selects, v.Selects) + out = &nv + default: + // We currently only handle expression and select statement. + // Will add more when we need to. + panic("unknown ast Node type " + fmt.Sprintf("%T", v)) + } + return +} diff --git a/ast/cloner_test.go b/ast/cloner_test.go new file mode 100644 index 0000000000..f6f567205b --- /dev/null +++ b/ast/cloner_test.go @@ -0,0 +1,40 @@ +package ast + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/parser/opcode" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testClonerSuite{}) + +type testClonerSuite struct { +} + +func (ts *testClonerSuite) TestCloner(c *C) { + cloner := &Cloner{} + + a := &UnaryOperationExpr{ + Op: opcode.Not, + V: &UnaryOperationExpr{V: NewValueExpr(true)}, + } + + b, ok := a.Accept(cloner) + c.Assert(ok, IsTrue) + a1 := a.V + b1 := b.(*UnaryOperationExpr).V + c.Assert(a1, Not(Equals), b1) + a2 := a1.(*UnaryOperationExpr).V + b2 := b1.(*UnaryOperationExpr).V + c.Assert(a2, Not(Equals), b2) + a3 := a2.(*ValueExpr) + b3 := b2.(*ValueExpr) + c.Assert(a3, Not(Equals), b3) + c.Assert(a3.GetValue(), Equals, true) + c.Assert(b3.GetValue(), Equals, true) +} diff --git a/ast/ddl.go b/ast/ddl.go index 0e43b69552..4f655d6f01 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -70,11 +70,13 @@ type CreateDatabaseStmt struct { } // Accept implements Node Accept interface. -func (cd *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (n *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cd) + n = newNod.(*CreateDatabaseStmt) + return v.Leave(n) } // DropDatabaseStmt is a statement to drop a database and all tables in the database. @@ -87,11 +89,13 @@ type DropDatabaseStmt struct { } // Accept implements Node Accept interface. -func (dd *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dd); skipChildren { - return dd, ok +func (n *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(dd) + n = newNod.(*DropDatabaseStmt) + return v.Leave(n) } // IndexColName is used for parsing index column name from SQL. @@ -103,16 +107,18 @@ type IndexColName struct { } // Accept implements Node Accept interface. -func (ic *IndexColName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ic); skipChildren { - return ic, ok +func (n *IndexColName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ic.Column.Accept(v) + n = newNod.(*IndexColName) + node, ok := n.Column.Accept(v) if !ok { - return ic, false + return n, false } - ic.Column = node.(*ColumnName) - return v.Leave(ic) + n.Column = node.(*ColumnName) + return v.Leave(n) } // ReferenceDef is used for parsing foreign key reference option from SQL. @@ -125,23 +131,25 @@ type ReferenceDef struct { } // Accept implements Node Accept interface. -func (rd *ReferenceDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rd); skipChildren { - return rd, ok +func (n *ReferenceDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := rd.Table.Accept(v) + n = newNod.(*ReferenceDef) + node, ok := n.Table.Accept(v) if !ok { - return rd, false + return n, false } - rd.Table = node.(*TableName) - for i, val := range rd.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return rd, false + return n, false } - rd.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(rd) + return v.Leave(n) } // ColumnOptionType is the type for ColumnOption. @@ -175,18 +183,20 @@ type ColumnOption struct { } // Accept implements Node Accept interface. -func (co *ColumnOption) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(co); skipChildren { - return co, ok +func (n *ColumnOption) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if co.Expr != nil { - node, ok := co.Expr.Accept(v) + n = newNod.(*ColumnOption) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return co, false + return n, false } - co.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(co) + return v.Leave(n) } // ConstraintType is the type for Constraint. @@ -220,25 +230,27 @@ type Constraint struct { } // Accept implements Node Accept interface. -func (tc *Constraint) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tc); skipChildren { - return tc, ok +func (n *Constraint) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range tc.Keys { + n = newNod.(*Constraint) + for i, val := range n.Keys { node, ok := val.Accept(v) if !ok { - return tc, false + return n, false } - tc.Keys[i] = node.(*IndexColName) + n.Keys[i] = node.(*IndexColName) } - if tc.Refer != nil { - node, ok := tc.Refer.Accept(v) + if n.Refer != nil { + node, ok := n.Refer.Accept(v) if !ok { - return tc, false + return n, false } - tc.Refer = node.(*ReferenceDef) + n.Refer = node.(*ReferenceDef) } - return v.Leave(tc) + return v.Leave(n) } // ColumnDef is used for parsing column definition from SQL. @@ -251,23 +263,25 @@ type ColumnDef struct { } // Accept implements Node Accept interface. -func (cd *ColumnDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (n *ColumnDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := cd.Name.Accept(v) + n = newNod.(*ColumnDef) + node, ok := n.Name.Accept(v) if !ok { - return cd, false + return n, false } - cd.Name = node.(*ColumnName) - for i, val := range cd.Options { + n.Name = node.(*ColumnName) + for i, val := range n.Options { node, ok := val.Accept(v) if !ok { - return cd, false + return n, false } - cd.Options[i] = node.(*ColumnOption) + n.Options[i] = node.(*ColumnOption) } - return v.Leave(cd) + return v.Leave(n) } // CreateTableStmt is a statement to create a table. @@ -283,30 +297,32 @@ type CreateTableStmt struct { } // Accept implements Node Accept interface. -func (ct *CreateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ct); skipChildren { - return ct, ok +func (n *CreateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ct.Table.Accept(v) + n = newNod.(*CreateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return ct, false + return n, false } - ct.Table = node.(*TableName) - for i, val := range ct.Cols { + n.Table = node.(*TableName) + for i, val := range n.Cols { node, ok = val.Accept(v) if !ok { - return ct, false + return n, false } - ct.Cols[i] = node.(*ColumnDef) + n.Cols[i] = node.(*ColumnDef) } - for i, val := range ct.Constraints { + for i, val := range n.Constraints { node, ok = val.Accept(v) if !ok { - return ct, false + return n, false } - ct.Constraints[i] = node.(*Constraint) + n.Constraints[i] = node.(*Constraint) } - return v.Leave(ct) + return v.Leave(n) } // DropTableStmt is a statement to drop one or more tables. @@ -319,18 +335,20 @@ type DropTableStmt struct { } // Accept implements Node Accept interface. -func (dt *DropTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dt); skipChildren { - return dt, ok +func (n *DropTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range dt.Tables { + n = newNod.(*DropTableStmt) + for i, val := range n.Tables { node, ok := val.Accept(v) if !ok { - return dt, false + return n, false } - dt.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - return v.Leave(dt) + return v.Leave(n) } // CreateIndexStmt is a statement to create an index. @@ -345,23 +363,25 @@ type CreateIndexStmt struct { } // Accept implements Node Accept interface. -func (ci *CreateIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ci); skipChildren { - return ci, ok +func (n *CreateIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ci.Table.Accept(v) + n = newNod.(*CreateIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return ci, false + return n, false } - ci.Table = node.(*TableName) - for i, val := range ci.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return ci, false + return n, false } - ci.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(ci) + return v.Leave(n) } // DropIndexStmt is a statement to drop the index. @@ -375,16 +395,18 @@ type DropIndexStmt struct { } // Accept implements Node Accept interface. -func (di *DropIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(di); skipChildren { - return di, ok +func (n *DropIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := di.Table.Accept(v) + n = newNod.(*DropIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return di, false + return n, false } - di.Table = node.(*TableName) - return v.Leave(di) + n.Table = node.(*TableName) + return v.Leave(n) } // TableOptionType is the type for TableOption @@ -435,18 +457,20 @@ type ColumnPosition struct { } // Accept implements Node Accept interface. -func (cp *ColumnPosition) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cp); skipChildren { - return cp, ok +func (n *ColumnPosition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if cp.RelativeColumn != nil { - node, ok := cp.RelativeColumn.Accept(v) + n = newNod.(*ColumnPosition) + if n.RelativeColumn != nil { + node, ok := n.RelativeColumn.Accept(v) if !ok { - return cp, false + return n, false } - cp.RelativeColumn = node.(*ColumnName) + n.RelativeColumn = node.(*ColumnName) } - return v.Leave(cp) + return v.Leave(n) } // AlterTableType is the type for AlterTableSpec. @@ -474,44 +498,46 @@ type AlterTableSpec struct { Constraint *Constraint Options []*TableOption Column *ColumnDef - ColumnName *ColumnName + DropColumn *ColumnName Position *ColumnPosition } // Accept implements Node Accept interface. -func (as *AlterTableSpec) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (n *AlterTableSpec) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if as.Constraint != nil { - node, ok := as.Constraint.Accept(v) + n = newNod.(*AlterTableSpec) + if n.Constraint != nil { + node, ok := n.Constraint.Accept(v) if !ok { - return as, false + return n, false } - as.Constraint = node.(*Constraint) + n.Constraint = node.(*Constraint) } - if as.Column != nil { - node, ok := as.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return as, false + return n, false } - as.Column = node.(*ColumnDef) + n.Column = node.(*ColumnDef) } - if as.ColumnName != nil { - node, ok := as.ColumnName.Accept(v) + if n.DropColumn != nil { + node, ok := n.DropColumn.Accept(v) if !ok { - return as, false + return n, false } - as.ColumnName = node.(*ColumnName) + n.DropColumn = node.(*ColumnName) } - if as.Position != nil { - node, ok := as.Position.Accept(v) + if n.Position != nil { + node, ok := n.Position.Accept(v) if !ok { - return as, false + return n, false } - as.Position = node.(*ColumnPosition) + n.Position = node.(*ColumnPosition) } - return v.Leave(as) + return v.Leave(n) } // AlterTableStmt is a statement to change the structure of a table. @@ -524,23 +550,25 @@ type AlterTableStmt struct { } // Accept implements Node Accept interface. -func (at *AlterTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(at); skipChildren { - return at, ok +func (n *AlterTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := at.Table.Accept(v) + n = newNod.(*AlterTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return at, false + return n, false } - at.Table = node.(*TableName) - for i, val := range at.Specs { + n.Table = node.(*TableName) + for i, val := range n.Specs { node, ok = val.Accept(v) if !ok { - return at, false + return n, false } - at.Specs[i] = node.(*AlterTableSpec) + n.Specs[i] = node.(*AlterTableSpec) } - return v.Leave(at) + return v.Leave(n) } // TruncateTableStmt is a statement to empty a table completely. @@ -552,14 +580,16 @@ type TruncateTableStmt struct { } // Accept implements Node Accept interface. -func (ts *TruncateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (n *TruncateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Table.Accept(v) + n = newNod.(*TruncateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return ts, false + return n, false } - ts.Table = node.(*TableName) - return v.Leave(ts) + n.Table = node.(*TableName) + return v.Leave(n) } diff --git a/ast/dml.go b/ast/dml.go index b8ada81d99..e15b5f83bb 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -22,6 +22,7 @@ var ( _ DMLNode = &DeleteStmt{} _ DMLNode = &UpdateStmt{} _ DMLNode = &SelectStmt{} + _ DMLNode = &UnionStmt{} _ Node = &Join{} _ Node = &TableName{} _ Node = &TableSource{} @@ -55,34 +56,36 @@ type Join struct { // Tp represents join type. Tp JoinType // On represents join on condition. - On ExprNode + On *OnCondition } // Accept implements Node Accept interface. -func (j *Join) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(j); skipChildren { - return j, ok +func (n *Join) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := j.Left.Accept(v) + n = newNod.(*Join) + node, ok := n.Left.Accept(v) if !ok { - return j, false + return n, false } - j.Left = node.(ResultSetNode) - if j.Right != nil { - node, ok = j.Right.Accept(v) + n.Left = node.(ResultSetNode) + if n.Right != nil { + node, ok = n.Right.Accept(v) if !ok { - return j, false + return n, false } - j.Right = node.(ResultSetNode) + n.Right = node.(ResultSetNode) } - if j.On != nil { - node, ok = j.On.Accept(v) + if n.On != nil { + node, ok = n.On.Accept(v) if !ok { - return j, false + return n, false } - j.On = node.(ExprNode) + n.On = node.(*OnCondition) } - return v.Leave(j) + return v.Leave(n) } // TableName represents a table name. @@ -98,11 +101,13 @@ type TableName struct { } // Accept implements Node Accept interface. -func (tr *TableName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tr); skipChildren { - return tr, ok +func (n *TableName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(tr) + n = newNod.(*TableName) + return v.Leave(n) } // TableSource represents table source with a name. @@ -110,7 +115,7 @@ type TableSource struct { node // Source is the source of the data, can be a TableName, - // a SubQuery, or a JoinNode. + // a SelectStmt, a UnionStmt, or a JoinNode. Source ResultSetNode // AsName is the as name of the table source. @@ -118,47 +123,50 @@ type TableSource struct { } // Accept implements Node Accept interface. -func (ts *TableSource) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (n *TableSource) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Source.Accept(v) + n = newNod.(*TableSource) + node, ok := n.Source.Accept(v) if !ok { - return ts, false + return n, false } - ts.Source = node.(ResultSetNode) - return v.Leave(ts) + n.Source = node.(ResultSetNode) + return v.Leave(n) } -// SetResultFields implements ResultSet interface. -func (ts *TableSource) SetResultFields(rfs []*ResultField) { - ts.Source.SetResultFields(rfs) -} - -// GetResultFields implements ResultSet interface. -func (ts *TableSource) GetResultFields() []*ResultField { - return ts.Source.GetResultFields() -} - -// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. -type UnionClause struct { +// OnCondition represetns JOIN on condition. +type OnCondition struct { node - Distinct bool - Select *SelectStmt + Expr ExprNode } // Accept implements Node Accept interface. -func (uc *UnionClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(uc); skipChildren { - return uc, ok +func (n *OnCondition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := uc.Select.Accept(v) + n = newNod.(*OnCondition) + node, ok := n.Expr.Accept(v) if !ok { - return uc, false + return n, false } - uc.Select = node.(*SelectStmt) - return v.Leave(uc) + n.Expr = node.(ExprNode) + return v.Leave(n) +} + +// SetResultFields implements ResultSet interface. +func (n *TableSource) SetResultFields(rfs []*ResultField) { + n.Source.SetResultFields(rfs) +} + +// GetResultFields implements ResultSet interface. +func (n *TableSource) GetResultFields() []*ResultField { + return n.Source.GetResultFields() } // SelectLockType is the lock type for SelectStmt. @@ -175,22 +183,18 @@ const ( type WildCardField struct { node - Table *TableName + Table model.CIStr + Schema model.CIStr } // Accept implements Node Accept interface. -func (wf *WildCardField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(wf); skipChildren { - return wf, ok +func (n *WildCardField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if wf.Table != nil { - node, ok := wf.Table.Accept(v) - if !ok { - return wf, false - } - wf.Table = node.(*TableName) - } - return v.Leave(wf) + n = newNod.(*WildCardField) + return v.Leave(n) } // SelectField represents fields in select statement. @@ -199,6 +203,8 @@ func (wf *WildCardField) Accept(v Visitor) (Node, bool) { type SelectField struct { node + // Offset is used to get original text. + Offset int // If WildCard is not nil, Expr will be nil. WildCard *WildCardField // If Expr is not nil, WildCard will be nil. @@ -208,22 +214,70 @@ type SelectField struct { } // Accept implements Node Accept interface. -func (sf *SelectField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (n *SelectField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if sf.Expr != nil { - node, ok := sf.Expr.Accept(v) + n = newNod.(*SelectField) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return sf, false + return n, false } - sf.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(sf) + return v.Leave(n) } -// OrderByItem represents a single order by item. -type OrderByItem struct { +// FieldList represents field list in select statement. +type FieldList struct { + node + + Fields []*SelectField +} + +// Accept implements Node Accept interface. +func (n *FieldList) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*FieldList) + for i, val := range n.Fields { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Fields[i] = node.(*SelectField) + } + return v.Leave(n) +} + +// TableRefsClause represents table references clause in dml statement. +type TableRefsClause struct { + node + + TableRefs *Join +} + +// Accept implements Node Accept interface. +func (n *TableRefsClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*TableRefsClause) + node, ok := n.TableRefs.Accept(v) + if !ok { + return n, false + } + n.TableRefs = node.(*Join) + return v.Leave(n) +} + +// ByItem represents an item in order by or group by. +type ByItem struct { node Expr ExprNode @@ -231,131 +285,244 @@ type OrderByItem struct { } // Accept implements Node Accept interface. -func (ob *OrderByItem) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ob); skipChildren { - return ob, ok +func (n *ByItem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ob.Expr.Accept(v) + n = newNod.(*ByItem) + node, ok := n.Expr.Accept(v) if !ok { - return ob, false + return n, false } - ob.Expr = node.(ExprNode) - return v.Leave(ob) + n.Expr = node.(ExprNode) + return v.Leave(n) +} + +// GroupByClause represents group by clause. +type GroupByClause struct { + node + Items []*ByItem +} + +// Accept implements Node Accept interface. +func (n *GroupByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*GroupByClause) + for i, val := range n.Items { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Items[i] = node.(*ByItem) + } + return v.Leave(n) +} + +// HavingClause represents having clause. +type HavingClause struct { + node + Expr ExprNode +} + +// Accept implements Node Accept interface. +func (n *HavingClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*HavingClause) + node, ok := n.Expr.Accept(v) + if !ok { + return n, false + } + n.Expr = node.(ExprNode) + return v.Leave(n) +} + +// OrderByClause represents order by clause. +type OrderByClause struct { + node + Items []*ByItem + ForUnion bool +} + +// Accept implements Node Accept interface. +func (n *OrderByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*OrderByClause) + for i, val := range n.Items { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Items[i] = node.(*ByItem) + } + return v.Leave(n) } // SelectStmt represents the select query node. +// See: https://dev.mysql.com/doc/refman/5.7/en/select.html type SelectStmt struct { dmlNode resultSetNode // Distinct represents if the select has distinct option. Distinct bool - // Fields is the select expression list. - Fields []*SelectField // From is the from clause of the query. - From *Join + From *TableRefsClause // Where is the where clause in select statement. Where ExprNode + // Fields is the select expression list. + Fields *FieldList // GroupBy is the group by expression list. - GroupBy []ExprNode + GroupBy *GroupByClause // Having is the having condition. - Having ExprNode - // OrderBy is the odering expression list. - OrderBy []*OrderByItem + Having *HavingClause + // OrderBy is the ordering expression list. + OrderBy *OrderByClause // Limit is the limit clause. Limit *Limit // Lock is the lock type LockTp SelectLockType - - // Union clauses. - Unions []*UnionClause - // Order by for union select. - UnionOrderBy []*OrderByItem - // Limit for union select. - UnionLimit *Limit } // Accept implements Node Accept interface. -func (sn *SelectStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sn); skipChildren { - return sn, ok +func (n *SelectStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range sn.Fields { - node, ok := val.Accept(v) + n = newNod.(*SelectStmt) + + if n.From != nil { + node, ok := n.From.Accept(v) if !ok { - return sn, false + return n, false } - sn.Fields[i] = node.(*SelectField) - } - if sn.From != nil { - node, ok := sn.From.Accept(v) - if !ok { - return sn, false - } - sn.From = node.(*Join) + n.From = node.(*TableRefsClause) } - if sn.Where != nil { - node, ok := sn.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return sn, false + return n, false } - sn.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - for i, val := range sn.GroupBy { - node, ok := val.Accept(v) + if n.Fields != nil { + node, ok := n.Fields.Accept(v) if !ok { - return sn, false + return n, false } - sn.GroupBy[i] = node.(ExprNode) - } - if sn.Having != nil { - node, ok := sn.Having.Accept(v) - if !ok { - return sn, false - } - sn.Having = node.(ExprNode) + n.Fields = node.(*FieldList) } - for i, val := range sn.OrderBy { - node, ok := val.Accept(v) + if n.GroupBy != nil { + node, ok := n.GroupBy.Accept(v) if !ok { - return sn, false + return n, false } - sn.OrderBy[i] = node.(*OrderByItem) + n.GroupBy = node.(*GroupByClause) } - if sn.Limit != nil { - node, ok := sn.Limit.Accept(v) + if n.Having != nil { + node, ok := n.Having.Accept(v) if !ok { - return sn, false + return n, false } - sn.Limit = node.(*Limit) + n.Having = node.(*HavingClause) } - for i, val := range sn.Unions { + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) + if !ok { + return n, false + } + n.OrderBy = node.(*OrderByClause) + } + + if n.Limit != nil { + node, ok := n.Limit.Accept(v) + if !ok { + return n, false + } + n.Limit = node.(*Limit) + } + return v.Leave(n) +} + +// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. +type UnionClause struct { + node + + Distinct bool + Select *SelectStmt +} + +// Accept implements Node Accept interface. +func (n *UnionClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*UnionClause) + node, ok := n.Select.Accept(v) + if !ok { + return n, false + } + n.Select = node.(*SelectStmt) + return v.Leave(n) +} + +// UnionStmt represents "union statement" +// See: https://dev.mysql.com/doc/refman/5.7/en/union.html +type UnionStmt struct { + dmlNode + resultSetNode + + Distinct bool + Selects []*SelectStmt + OrderBy *OrderByClause + Limit *Limit +} + +// Accept implements Node Accept interface. +func (n *UnionStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*UnionStmt) + for i, val := range n.Selects { node, ok := val.Accept(v) if !ok { - return sn, false + return n, false } - sn.Unions[i] = node.(*UnionClause) + n.Selects[i] = node.(*SelectStmt) } - for i, val := range sn.UnionOrderBy { - node, ok := val.Accept(v) + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) if !ok { - return sn, false + return n, false } - sn.UnionOrderBy[i] = node.(*OrderByItem) + n.OrderBy = node.(*OrderByClause) } - if sn.UnionLimit != nil { - node, ok := sn.UnionLimit.Accept(v) + if n.Limit != nil { + node, ok := n.Limit.Accept(v) if !ok { - return sn, false + return n, false } - sn.UnionLimit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(sn) + return v.Leave(n) } // Assignment is the expression for assignment, like a = 1. @@ -368,21 +535,23 @@ type Assignment struct { } // Accept implements Node Accept interface. -func (as *Assignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (n *Assignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := as.Column.Accept(v) + n = newNod.(*Assignment) + node, ok := n.Column.Accept(v) if !ok { - return as, false + return n, false } - as.Column = node.(*ColumnName) - node, ok = as.Expr.Accept(v) + n.Column = node.(*ColumnName) + node, ok = n.Expr.Accept(v) if !ok { - return as, false + return n, false } - as.Expr = node.(ExprNode) - return v.Leave(as) + n.Expr = node.(ExprNode) + return v.Leave(n) } // Priority const values. @@ -399,58 +568,68 @@ const ( type InsertStmt struct { dmlNode + Replace bool + Table *TableRefsClause Columns []*ColumnName Lists [][]ExprNode - Table *TableName Setlist []*Assignment Priority int OnDuplicate []*Assignment - Select *SelectStmt + Select ResultSetNode } // Accept implements Node Accept interface. -func (in *InsertStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(in); skipChildren { - return in, ok +func (n *InsertStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range in.Columns { + n = newNod.(*InsertStmt) + + if n.Select != nil { + node, ok := n.Select.Accept(v) + if !ok { + return n, false + } + n.Select = node.(ResultSetNode) + } + node, ok := n.Table.Accept(v) + if !ok { + return n, false + } + n.Table = node.(*TableRefsClause) + + for i, val := range n.Columns { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.Columns[i] = node.(*ColumnName) + n.Columns[i] = node.(*ColumnName) } - for i, list := range in.Lists { + for i, list := range n.Lists { for j, val := range list { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.Lists[i][j] = node.(ExprNode) + n.Lists[i][j] = node.(ExprNode) } } - for i, val := range in.Setlist { + for i, val := range n.Setlist { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.Setlist[i] = node.(*Assignment) + n.Setlist[i] = node.(*Assignment) } - for i, val := range in.OnDuplicate { + for i, val := range n.OnDuplicate { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.OnDuplicate[i] = node.(*Assignment) + n.OnDuplicate[i] = node.(*Assignment) } - if in.Select != nil { - node, ok := in.Select.Accept(v) - if !ok { - return in, false - } - in.Select = node.(*SelectStmt) - } - return v.Leave(in) + return v.Leave(n) } // DeleteStmt is a statement to delete rows from table. @@ -458,10 +637,12 @@ func (in *InsertStmt) Accept(v Visitor) (Node, bool) { type DeleteStmt struct { dmlNode - TableRefs *Join + // Used in both single table and multiple table delete statement. + TableRefs *TableRefsClause + // Only used in multiple table delete statement. Tables []*TableName Where ExprNode - Order []*OrderByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -471,47 +652,49 @@ type DeleteStmt struct { } // Accept implements Node Accept interface. -func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(de); skipChildren { - return de, ok +func (n *DeleteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } + n = newNod.(*DeleteStmt) - node, ok := de.TableRefs.Accept(v) + node, ok := n.TableRefs.Accept(v) if !ok { - return de, false + return n, false } - de.TableRefs = node.(*Join) + n.TableRefs = node.(*TableRefsClause) - for i, val := range de.Tables { + for i, val := range n.Tables { node, ok = val.Accept(v) if !ok { - return de, false + return n, false } - de.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - if de.Where != nil { - node, ok = de.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return de, false + return n, false } - de.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - - for i, val := range de.Order { - node, ok = val.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return de, false + return n, false } - de.Order[i] = node.(*OrderByItem) + n.Order = node.(*OrderByClause) } - - node, ok = de.Limit.Accept(v) - if !ok { - return de, false + if n.Limit != nil { + node, ok = n.Limit.Accept(v) + if !ok { + return n, false + } + n.Limit = node.(*Limit) } - de.Limit = node.(*Limit) - return v.Leave(de) + return v.Leave(n) } // UpdateStmt is a statement to update columns of existing rows in tables with new values. @@ -519,10 +702,10 @@ func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { type UpdateStmt struct { dmlNode - TableRefs *Join + TableRefs *TableRefsClause List []*Assignment Where ExprNode - Order []*OrderByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -530,43 +713,46 @@ type UpdateStmt struct { } // Accept implements Node Accept interface. -func (up *UpdateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(up); skipChildren { - return up, ok +func (n *UpdateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := up.TableRefs.Accept(v) + n = newNod.(*UpdateStmt) + node, ok := n.TableRefs.Accept(v) if !ok { - return up, false + return n, false } - up.TableRefs = node.(*Join) - for i, val := range up.List { + n.TableRefs = node.(*TableRefsClause) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return up, false + return n, false } - up.List[i] = node.(*Assignment) + n.List[i] = node.(*Assignment) } - if up.Where != nil { - node, ok = up.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return up, false + return n, false } - up.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - - for i, val := range up.Order { - node, ok = val.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return up, false + return n, false } - up.Order[i] = node.(*OrderByItem) + n.Order = node.(*OrderByClause) } - node, ok = up.Limit.Accept(v) - if !ok { - return up, false + if n.Limit != nil { + node, ok = n.Limit.Accept(v) + if !ok { + return n, false + } + n.Limit = node.(*Limit) } - up.Limit = node.(*Limit) - return v.Leave(up) + return v.Leave(n) } // Limit is the limit clause. @@ -578,9 +764,11 @@ type Limit struct { } // Accept implements Node Accept interface. -func (l *Limit) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(l); skipChildren { - return l, ok +func (n *Limit) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(l) + n = newNod.(*Limit) + return v.Leave(n) } diff --git a/ast/expressions.go b/ast/expressions.go index 558f1a019c..c3278bd331 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -14,8 +14,11 @@ package ast import ( + "fmt" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" ) var ( @@ -26,6 +29,7 @@ var ( _ ExprNode = &CaseExpr{} _ ExprNode = &SubqueryExpr{} _ ExprNode = &CompareSubqueryExpr{} + _ Node = &ColumnName{} _ ExprNode = &ColumnNameExpr{} _ ExprNode = &DefaultExpr{} _ ExprNode = &IdentifierExpr{} @@ -47,21 +51,58 @@ var ( // ValueExpr is the simple value expression. type ValueExpr struct { exprNode - // Val is the literal value. - Val interface{} +} + +// NewValueExpr creates a ValueExpr with value, and sets default field type. +func NewValueExpr(value interface{}) *ValueExpr { + ve := &ValueExpr{} + ve.Data = types.RawData(value) + // TODO: make it more precise. + switch value.(type) { + case nil: + ve.Type = types.NewFieldType(mysql.TypeNull) + case bool, int64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + case uint64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + ve.Type.Flag |= mysql.UnsignedFlag + case string, UnquoteString: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = mysql.DefaultCharset + ve.Type.Collate = mysql.DefaultCollationName + case float64: + ve.Type = types.NewFieldType(mysql.TypeDouble) + case []byte: + ve.Type = types.NewFieldType(mysql.TypeBlob) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + case mysql.Bit: + ve.Type = types.NewFieldType(mysql.TypeBit) + case mysql.Hex: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + case *types.DataItem: + ve.Type = value.(*types.DataItem).Type + default: + panic(fmt.Sprintf("illegal literal value type:%T", value)) + } + return ve } // IsStatic implements ExprNode interface. -func (val *ValueExpr) IsStatic() bool { +func (n *ValueExpr) IsStatic() bool { return true } // Accept implements Node interface. -func (val *ValueExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(val); skipChildren { - return val, ok +func (n *ValueExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(val) + n = newNod.(*ValueExpr) + return v.Leave(n) } // BetweenExpr is for "between and" or "not between and" expression. @@ -78,35 +119,37 @@ type BetweenExpr struct { } // Accept implements Node interface. -func (b *BetweenExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(b); skipChildren { - return b, ok +func (n *BetweenExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } + n = newNod.(*BetweenExpr) - node, ok := b.Expr.Accept(v) + node, ok := n.Expr.Accept(v) if !ok { - return b, false + return n, false } - b.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = b.Left.Accept(v) + node, ok = n.Left.Accept(v) if !ok { - return b, false + return n, false } - b.Left = node.(ExprNode) + n.Left = node.(ExprNode) - node, ok = b.Right.Accept(v) + node, ok = n.Right.Accept(v) if !ok { - return b, false + return n, false } - b.Right = node.(ExprNode) + n.Right = node.(ExprNode) - return v.Leave(b) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (b *BetweenExpr) IsStatic() bool { - return b.Expr.IsStatic() && b.Left.IsStatic() && b.Right.IsStatic() +func (n *BetweenExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Left.IsStatic() && n.Right.IsStatic() } // BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc. @@ -121,29 +164,31 @@ type BinaryOperationExpr struct { } // Accept implements Node interface. -func (o *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(o); skipChildren { - return o, ok +func (n *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } + n = newNod.(*BinaryOperationExpr) - node, ok := o.L.Accept(v) + node, ok := n.L.Accept(v) if !ok { - return o, false + return n, false } - o.L = node.(ExprNode) + n.L = node.(ExprNode) - node, ok = o.R.Accept(v) + node, ok = n.R.Accept(v) if !ok { - return o, false + return n, false } - o.R = node.(ExprNode) + n.R = node.(ExprNode) - return v.Leave(o) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (o *BinaryOperationExpr) IsStatic() bool { - return o.L.IsStatic() && o.R.IsStatic() +func (n *BinaryOperationExpr) IsStatic() bool { + return n.L.IsStatic() && n.R.IsStatic() } // WhenClause is the when clause in Case expression for "when condition then result". @@ -156,27 +201,29 @@ type WhenClause struct { } // Accept implements Node Accept interface. -func (w *WhenClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(w); skipChildren { - return w, ok +func (n *WhenClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := w.Expr.Accept(v) + n = newNod.(*WhenClause) + node, ok := n.Expr.Accept(v) if !ok { - return w, false + return n, false } - w.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = w.Result.Accept(v) + node, ok = n.Result.Accept(v) if !ok { - return w, false + return n, false } - w.Result = node.(ExprNode) - return v.Leave(w) + n.Result = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (w *WhenClause) IsStatic() bool { - return w.Expr.IsStatic() && w.Result.IsStatic() +func (n *WhenClause) IsStatic() bool { + return n.Expr.IsStatic() && n.Result.IsStatic() } // CaseExpr is the case expression. @@ -191,45 +238,47 @@ type CaseExpr struct { } // Accept implements Node Accept interface. -func (f *CaseExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (n *CaseExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if f.Value != nil { - node, ok := f.Value.Accept(v) + n = newNod.(*CaseExpr) + if n.Value != nil { + node, ok := n.Value.Accept(v) if !ok { - return f, false + return n, false } - f.Value = node.(ExprNode) + n.Value = node.(ExprNode) } - for i, val := range f.WhenClauses { + for i, val := range n.WhenClauses { node, ok := val.Accept(v) if !ok { - return f, false + return n, false } - f.WhenClauses[i] = node.(*WhenClause) + n.WhenClauses[i] = node.(*WhenClause) } - if f.ElseClause != nil { - node, ok := f.ElseClause.Accept(v) + if n.ElseClause != nil { + node, ok := n.ElseClause.Accept(v) if !ok { - return f, false + return n, false } - f.ElseClause = node.(ExprNode) + n.ElseClause = node.(ExprNode) } - return v.Leave(f) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (f *CaseExpr) IsStatic() bool { - if f.Value != nil && !f.Value.IsStatic() { +func (n *CaseExpr) IsStatic() bool { + if n.Value != nil && !n.Value.IsStatic() { return false } - for _, w := range f.WhenClauses { + for _, w := range n.WhenClauses { if !w.IsStatic() { return false } } - if f.ElseClause != nil && !f.ElseClause.IsStatic() { + if n.ElseClause != nil && !n.ElseClause.IsStatic() { return false } return true @@ -239,30 +288,32 @@ func (f *CaseExpr) IsStatic() bool { type SubqueryExpr struct { exprNode // Query is the query SelectNode. - Query *SelectStmt + Query ResultSetNode } // Accept implements Node Accept interface. -func (sq *SubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sq); skipChildren { - return sq, ok +func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := sq.Query.Accept(v) + n = newNod.(*SubqueryExpr) + node, ok := n.Query.Accept(v) if !ok { - return sq, false + return n, false } - sq.Query = node.(*SelectStmt) - return v.Leave(sq) + n.Query = node.(ResultSetNode) + return v.Leave(n) } // SetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) SetResultFields(rfs []*ResultField) { - sq.Query.SetResultFields(rfs) +func (n *SubqueryExpr) SetResultFields(rfs []*ResultField) { + n.Query.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) GetResultFields() []*ResultField { - return sq.Query.GetResultFields() +func (n *SubqueryExpr) GetResultFields() []*ResultField { + return n.Query.GetResultFields() } // CompareSubqueryExpr is the expression for "expr cmp (select ...)". @@ -282,21 +333,23 @@ type CompareSubqueryExpr struct { } // Accept implements Node Accept interface. -func (cs *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (n *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := cs.L.Accept(v) + n = newNod.(*CompareSubqueryExpr) + node, ok := n.L.Accept(v) if !ok { - return cs, false + return n, false } - cs.L = node.(ExprNode) - node, ok = cs.R.Accept(v) + n.L = node.(ExprNode) + node, ok = n.R.Accept(v) if !ok { - return cs, false + return n, false } - cs.R = node.(*SubqueryExpr) - return v.Leave(cs) + n.R = node.(*SubqueryExpr) + return v.Leave(n) } // ColumnName represents column name. @@ -312,11 +365,13 @@ type ColumnName struct { } // Accept implements Node Accept interface. -func (cn *ColumnName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cn); skipChildren { - return cn, ok +func (n *ColumnName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cn) + n = newNod.(*ColumnName) + return v.Leave(n) } // ColumnNameExpr represents a column name expression. @@ -328,16 +383,18 @@ type ColumnNameExpr struct { } // Accept implements Node Accept interface. -func (cr *ColumnNameExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cr); skipChildren { - return cr, ok +func (n *ColumnNameExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := cr.Name.Accept(v) + n = newNod.(*ColumnNameExpr) + node, ok := n.Name.Accept(v) if !ok { - return cr, false + return n, false } - cr.Name = node.(*ColumnName) - return v.Leave(cr) + n.Name = node.(*ColumnName) + return v.Leave(n) } // DefaultExpr is the default expression using default value for a column. @@ -348,18 +405,20 @@ type DefaultExpr struct { } // Accept implements Node Accept interface. -func (d *DefaultExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(d); skipChildren { - return d, ok +func (n *DefaultExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if d.Name != nil { - node, ok := d.Name.Accept(v) + n = newNod.(*DefaultExpr) + if n.Name != nil { + node, ok := n.Name.Accept(v) if !ok { - return d, false + return n, false } - d.Name = node.(*ColumnName) + n.Name = node.(*ColumnName) } - return v.Leave(d) + return v.Leave(n) } // IdentifierExpr represents an identifier expression. @@ -370,11 +429,13 @@ type IdentifierExpr struct { } // Accept implements Node Accept interface. -func (i *IdentifierExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(i); skipChildren { - return i, ok +func (n *IdentifierExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(i) + n = newNod.(*IdentifierExpr) + return v.Leave(n) } // ExistsSubqueryExpr is the expression for "exists (select ...)". @@ -386,16 +447,18 @@ type ExistsSubqueryExpr struct { } // Accept implements Node Accept interface. -func (es *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (n *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Sel.Accept(v) + n = newNod.(*ExistsSubqueryExpr) + node, ok := n.Sel.Accept(v) if !ok { - return es, false + return n, false } - es.Sel = node.(*SubqueryExpr) - return v.Leave(es) + n.Sel = node.(*SubqueryExpr) + return v.Leave(n) } // PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)". @@ -412,30 +475,32 @@ type PatternInExpr struct { } // Accept implements Node Accept interface. -func (pi *PatternInExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pi); skipChildren { - return pi, ok +func (n *PatternInExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := pi.Expr.Accept(v) + n = newNod.(*PatternInExpr) + node, ok := n.Expr.Accept(v) if !ok { - return pi, false + return n, false } - pi.Expr = node.(ExprNode) - for i, val := range pi.List { + n.Expr = node.(ExprNode) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return pi, false + return n, false } - pi.List[i] = node.(ExprNode) + n.List[i] = node.(ExprNode) } - if pi.Sel != nil { - node, ok = pi.Sel.Accept(v) + if n.Sel != nil { + node, ok = n.Sel.Accept(v) if !ok { - return pi, false + return n, false } - pi.Sel = node.(*SubqueryExpr) + n.Sel = node.(*SubqueryExpr) } - return v.Leave(pi) + return v.Leave(n) } // IsNullExpr is the expression for null check. @@ -448,21 +513,23 @@ type IsNullExpr struct { } // Accept implements Node Accept interface. -func (is *IsNullExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (n *IsNullExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + n = newNod.(*IsNullExpr) + node, ok := n.Expr.Accept(v) if !ok { - return is, false + return n, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsNullExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (n *IsNullExpr) IsStatic() bool { + return n.Expr.IsStatic() } // IsTruthExpr is the expression for true/false check. @@ -477,21 +544,23 @@ type IsTruthExpr struct { } // Accept implements Node Accept interface. -func (is *IsTruthExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (n *IsTruthExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + n = newNod.(*IsTruthExpr) + node, ok := n.Expr.Accept(v) if !ok { - return is, false + return n, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsTruthExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (n *IsTruthExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PatternLikeExpr is the expression for like operator, e.g, expr like "%123%" @@ -508,40 +577,49 @@ type PatternLikeExpr struct { } // Accept implements Node Accept interface. -func (pl *PatternLikeExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pl); skipChildren { - return pl, ok +func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := pl.Expr.Accept(v) - if !ok { - return pl, false + n = newNod.(*PatternLikeExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) + if !ok { + return n, false + } + n.Expr = node.(ExprNode) } - pl.Expr = node.(ExprNode) - node, ok = pl.Pattern.Accept(v) - if !ok { - return pl, false + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) + if !ok { + return n, false + } + n.Pattern = node.(ExprNode) } - pl.Pattern = node.(ExprNode) - return v.Leave(pl) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (pl *PatternLikeExpr) IsStatic() bool { - return pl.Expr.IsStatic() && pl.Pattern.IsStatic() +func (n *PatternLikeExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // ParamMarkerExpr expresion holds a place for another expression. // Used in parsing prepare statement. type ParamMarkerExpr struct { exprNode + Offset int } // Accept implements Node Accept interface. -func (pm *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pm); skipChildren { - return pm, ok +func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(pm) + n = newNod.(*ParamMarkerExpr) + return v.Leave(n) } // ParenthesesExpr is the parentheses expression. @@ -552,23 +630,25 @@ type ParenthesesExpr struct { } // Accept implements Node Accept interface. -func (p *ParenthesesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (n *ParenthesesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if p.Expr != nil { - node, ok := p.Expr.Accept(v) + n = newNod.(*ParenthesesExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return p, false + return n, false } - p.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(p) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (p *ParenthesesExpr) IsStatic() bool { - return p.Expr.IsStatic() +func (n *ParenthesesExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PositionExpr is the expression for order by and group by position. @@ -583,16 +663,18 @@ type PositionExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (p *PositionExpr) IsStatic() bool { +func (n *PositionExpr) IsStatic() bool { return true } // Accept implements Node Accept interface. -func (p *PositionExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (n *PositionExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(p) + n = newNod.(*PositionExpr) + return v.Leave(n) } // PatternRegexpExpr is the pattern expression for pattern match. @@ -607,26 +689,28 @@ type PatternRegexpExpr struct { } // Accept implements Node Accept interface. -func (p *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (n *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := p.Expr.Accept(v) + n = newNod.(*PatternRegexpExpr) + node, ok := n.Expr.Accept(v) if !ok { - return p, false + return n, false } - p.Expr = node.(ExprNode) - node, ok = p.Pattern.Accept(v) + n.Expr = node.(ExprNode) + node, ok = n.Pattern.Accept(v) if !ok { - return p, false + return n, false } - p.Pattern = node.(ExprNode) - return v.Leave(p) + n.Pattern = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (p *PatternRegexpExpr) IsStatic() bool { - return p.Expr.IsStatic() && p.Pattern.IsStatic() +func (n *PatternRegexpExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // RowExpr is the expression for row constructor. @@ -638,23 +722,25 @@ type RowExpr struct { } // Accept implements Node Accept interface. -func (r *RowExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(r); skipChildren { - return r, ok +func (n *RowExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range r.Values { + n = newNod.(*RowExpr) + for i, val := range n.Values { node, ok := val.Accept(v) if !ok { - return r, false + return n, false } - r.Values[i] = node.(ExprNode) + n.Values[i] = node.(ExprNode) } - return v.Leave(r) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (r *RowExpr) IsStatic() bool { - for _, v := range r.Values { +func (n *RowExpr) IsStatic() bool { + for _, v := range n.Values { if !v.IsStatic() { return false } @@ -672,21 +758,23 @@ type UnaryOperationExpr struct { } // Accept implements Node Accept interface. -func (u *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(u); skipChildren { - return u, ok +func (n *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := u.V.Accept(v) + n = newNod.(*UnaryOperationExpr) + node, ok := n.V.Accept(v) if !ok { - return u, false + return n, false } - u.V = node.(ExprNode) - return v.Leave(u) + n.V = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (u *UnaryOperationExpr) IsStatic() bool { - return u.V.IsStatic() +func (n *UnaryOperationExpr) IsStatic() bool { + return n.V.IsStatic() } // ValuesExpr is the expression used in INSERT VALUES @@ -697,16 +785,18 @@ type ValuesExpr struct { } // Accept implements Node Accept interface. -func (va *ValuesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (n *ValuesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Column.Accept(v) + n = newNod.(*ValuesExpr) + node, ok := n.Column.Accept(v) if !ok { - return va, false + return n, false } - va.Column = node.(*ColumnName) - return v.Leave(va) + n.Column = node.(*ColumnName) + return v.Leave(n) } // VariableExpr is the expression for variable. @@ -721,9 +811,11 @@ type VariableExpr struct { } // Accept implements Node Accept interface. -func (va *VariableExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (n *VariableExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(va) + n = newNod.(*VariableExpr) + return v.Leave(n) } diff --git a/ast/functions.go b/ast/functions.go index 3843478313..7faab4c6e1 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -26,44 +26,49 @@ var ( _ FuncNode = &FuncConvertExpr{} _ FuncNode = &FuncCastExpr{} _ FuncNode = &FuncSubstringExpr{} + _ FuncNode = &FuncLocateExpr{} _ FuncNode = &FuncTrimExpr{} + _ FuncNode = &FuncDateArithExpr{} + _ FuncNode = &AggregateFuncExpr{} ) +// UnquoteString is not quoted when printed. +type UnquoteString string + // FuncCallExpr is for function expression. type FuncCallExpr struct { funcNode // F is the function name. - F string + FnName string // Args is the function args. Args []ExprNode - // Distinct only affetcts sum, avg, count, group_concat, - // so we can ignore it in other functions - Distinct bool } // Accept implements Node interface. -func (c *FuncCallExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(c); skipChildren { - return c, ok +func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range c.Args { + n = newNod.(*FuncCallExpr) + for i, val := range n.Args { node, ok := val.Accept(v) if !ok { - return c, false + return n, false } - c.Args[i] = node.(ExprNode) + n.Args[i] = node.(ExprNode) } - return v.Leave(c) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (c *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(c.F)] +func (n *FuncCallExpr) IsStatic() bool { + v := builtin.Funcs[strings.ToLower(n.FnName)] if v.F == nil || !v.IsStatic { return false } - for _, v := range c.Args { + for _, v := range n.Args { if !v.IsStatic() { return false } @@ -81,21 +86,23 @@ type FuncExtractExpr struct { } // Accept implements Node Accept interface. -func (ex *FuncExtractExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ex); skipChildren { - return ex, ok +func (n *FuncExtractExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ex.Date.Accept(v) + n = newNod.(*FuncExtractExpr) + node, ok := n.Date.Accept(v) if !ok { - return ex, false + return n, false } - ex.Date = node.(ExprNode) - return v.Leave(ex) + n.Date = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (ex *FuncExtractExpr) IsStatic() bool { - return ex.Date.IsStatic() +func (n *FuncExtractExpr) IsStatic() bool { + return n.Date.IsStatic() } // FuncConvertExpr provides a way to convert data between different character sets. @@ -109,21 +116,23 @@ type FuncConvertExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncConvertExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (n *FuncConvertExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncConvertExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (n *FuncConvertExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + n = newNod.(*FuncConvertExpr) + node, ok := n.Expr.Accept(v) if !ok { - return f, false + return n, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + n.Expr = node.(ExprNode) + return v.Leave(n) } // CastFunctionType is the type for cast function. @@ -149,21 +158,23 @@ type FuncCastExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncCastExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (n *FuncCastExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncCastExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + n = newNod.(*FuncCastExpr) + node, ok := n.Expr.Accept(v) if !ok { - return f, false + return n, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + n.Expr = node.(ExprNode) + return v.Leave(n) } // FuncSubstringExpr returns the substring as specified. @@ -177,31 +188,35 @@ type FuncSubstringExpr struct { } // Accept implements Node Accept interface. -func (sf *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (n *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := sf.StrExpr.Accept(v) + n = newNod.(*FuncSubstringExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return sf, false + return n, false } - sf.StrExpr = node.(ExprNode) - node, ok = sf.Pos.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return sf, false + return n, false } - sf.Pos = node.(ExprNode) - node, ok = sf.Len.Accept(v) - if !ok { - return sf, false + n.Pos = node.(ExprNode) + if n.Len != nil { + node, ok = n.Len.Accept(v) + if !ok { + return n, false + } + n.Len = node.(ExprNode) } - sf.Len = node.(ExprNode) - return v.Leave(sf) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (sf *FuncSubstringExpr) IsStatic() bool { - return sf.StrExpr.IsStatic() && sf.Pos.IsStatic() && sf.Len.IsStatic() +func (n *FuncSubstringExpr) IsStatic() bool { + return n.StrExpr.IsStatic() && n.Pos.IsStatic() && n.Len.IsStatic() } // FuncSubstringIndexExpr returns the substring as specified. @@ -215,26 +230,28 @@ type FuncSubstringIndexExpr struct { } // Accept implements Node Accept interface. -func (si *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(si); skipChildren { - return si, ok +func (n *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := si.StrExpr.Accept(v) + n = newNod.(*FuncSubstringIndexExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return si, false + return n, false } - si.StrExpr = node.(ExprNode) - node, ok = si.Delim.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Delim.Accept(v) if !ok { - return si, false + return n, false } - si.Delim = node.(ExprNode) - node, ok = si.Count.Accept(v) + n.Delim = node.(ExprNode) + node, ok = n.Count.Accept(v) if !ok { - return si, false + return n, false } - si.Count = node.(ExprNode) - return v.Leave(si) + n.Count = node.(ExprNode) + return v.Leave(n) } // FuncLocateExpr returns the position of the first occurrence of substring. @@ -248,26 +265,28 @@ type FuncLocateExpr struct { } // Accept implements Node Accept interface. -func (le *FuncLocateExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(le); skipChildren { - return le, ok +func (n *FuncLocateExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := le.Str.Accept(v) + n = newNod.(*FuncLocateExpr) + node, ok := n.Str.Accept(v) if !ok { - return le, false + return n, false } - le.Str = node.(ExprNode) - node, ok = le.SubStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.SubStr.Accept(v) if !ok { - return le, false + return n, false } - le.SubStr = node.(ExprNode) - node, ok = le.Pos.Accept(v) + n.SubStr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return le, false + return n, false } - le.Pos = node.(ExprNode) - return v.Leave(le) + n.Pos = node.(ExprNode) + return v.Leave(n) } // TrimDirectionType is the type for trim direction. @@ -295,27 +314,102 @@ type FuncTrimExpr struct { } // Accept implements Node Accept interface. -func (tf *FuncTrimExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tf); skipChildren { - return tf, ok +func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := tf.Str.Accept(v) + n = newNod.(*FuncTrimExpr) + node, ok := n.Str.Accept(v) if !ok { - return tf, false + return n, false } - tf.Str = node.(ExprNode) - node, ok = tf.RemStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.RemStr.Accept(v) if !ok { - return tf, false + return n, false } - tf.RemStr = node.(ExprNode) - return v.Leave(tf) + n.RemStr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (tf *FuncTrimExpr) IsStatic() bool { - return tf.Str.IsStatic() && tf.RemStr.IsStatic() +func (n *FuncTrimExpr) IsStatic() bool { + return n.Str.IsStatic() && n.RemStr.IsStatic() } -// TypeStar is a special type for "*". -type TypeStar string +// DateArithType is type for DateArith option. +type DateArithType byte + +const ( + // DateAdd is to run date_add function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add + DateAdd DateArithType = iota + 1 + // DateSub is to run date_sub function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub + DateSub +) + +// FuncDateArithExpr is the struct for date arithmetic functions. +type FuncDateArithExpr struct { + funcNode + + Op DateArithType + Unit string + Date ExprNode + Interval ExprNode +} + +// Accept implements Node Accept interface. +func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*FuncDateArithExpr) + if n.Date != nil { + node, ok := n.Date.Accept(v) + if !ok { + return n, false + } + n.Date = node.(ExprNode) + } + if n.Interval != nil { + node, ok := n.Date.Accept(v) + if !ok { + return n, false + } + n.Date = node.(ExprNode) + } + return v.Leave(n) +} + +// AggregateFuncExpr represents aggregate function expression. +type AggregateFuncExpr struct { + funcNode + // F is the function name. + F string + // Args is the function args. + Args []ExprNode + // If distinct is true, the function only aggregate distinct values. + // For example, column c1 values are "1", "2", "2", "sum(c1)" is "5", + // but "sum(distinct c1)" is "3". + Distinct bool +} + +// Accept implements Node Accept interface. +func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*AggregateFuncExpr) + for i, val := range n.Args { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Args[i] = node.(ExprNode) + } + return v.Leave(n) +} diff --git a/ast/misc.go b/ast/misc.go index 2ad94e7755..4f287b4680 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -13,6 +13,8 @@ package ast +import "github.com/pingcap/tidb/mysql" + var ( _ StmtNode = &ExplainStmt{} _ StmtNode = &PrepareStmt{} @@ -26,7 +28,9 @@ var ( _ StmtNode = &SetStmt{} _ StmtNode = &SetCharsetStmt{} _ StmtNode = &SetPwdStmt{} + _ StmtNode = &CreateUserStmt{} _ StmtNode = &DoStmt{} + _ StmtNode = &GrantStmt{} _ Node = &VariableAssignment{} ) @@ -57,16 +61,18 @@ type ExplainStmt struct { } // Accept implements Node Accept interface. -func (es *ExplainStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (n *ExplainStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Stmt.Accept(v) + n = newNod.(*ExplainStmt) + node, ok := n.Stmt.Accept(v) if !ok { - return es, false + return n, false } - es.Stmt = node.(DMLNode) - return v.Leave(es) + n.Stmt = node.(DMLNode) + return v.Leave(n) } // PrepareStmt is a statement to prepares a SQL statement which contains placeholders, @@ -83,16 +89,18 @@ type PrepareStmt struct { } // Accept implements Node Accept interface. -func (ps *PrepareStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ps); skipChildren { - return ps, ok +func (n *PrepareStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ps.SQLVar.Accept(v) + n = newNod.(*PrepareStmt) + node, ok := n.SQLVar.Accept(v) if !ok { - return ps, false + return n, false } - ps.SQLVar = node.(*VariableExpr) - return v.Leave(ps) + n.SQLVar = node.(*VariableExpr) + return v.Leave(n) } // DeallocateStmt is a statement to release PreparedStmt. @@ -105,11 +113,13 @@ type DeallocateStmt struct { } // Accept implements Node Accept interface. -func (ds *DeallocateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ds); skipChildren { - return ds, ok +func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(ds) + n = newNod.(*DeallocateStmt) + return v.Leave(n) } // ExecuteStmt is a statement to execute PreparedStmt. @@ -123,18 +133,20 @@ type ExecuteStmt struct { } // Accept implements Node Accept interface. -func (es *ExecuteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range es.UsingVars { + n = newNod.(*ExecuteStmt) + for i, val := range n.UsingVars { node, ok := val.Accept(v) if !ok { - return es, false + return n, false } - es.UsingVars[i] = node.(ExprNode) + n.UsingVars[i] = node.(ExprNode) } - return v.Leave(es) + return v.Leave(n) } // ShowStmtType is the type for SHOW statement. @@ -152,12 +164,13 @@ const ( ShowVariables ShowCollation ShowCreateTable + ShowGrants ) // ShowStmt is a statement to provide information about databases, tables, columns and so on. // See: https://dev.mysql.com/doc/refman/5.7/en/show.html type ShowStmt struct { - stmtNode + dmlNode Tp ShowStmtType // Databases/Tables/Columns/.... DBName string @@ -165,6 +178,7 @@ type ShowStmt struct { Column *ColumnName // Used for `desc table column`. Flag int // Some flag parsed from sql, such as FULL. Full bool + User string // Used for show grants. // Used by show variables GlobalScope bool @@ -173,39 +187,41 @@ type ShowStmt struct { } // Accept implements Node Accept interface. -func (ss *ShowStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ss); skipChildren { - return ss, ok +func (n *ShowStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if ss.Table != nil { - node, ok := ss.Table.Accept(v) + n = newNod.(*ShowStmt) + if n.Table != nil { + node, ok := n.Table.Accept(v) if !ok { - return ss, false + return n, false } - ss.Table = node.(*TableName) + n.Table = node.(*TableName) } - if ss.Column != nil { - node, ok := ss.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return ss, false + return n, false } - ss.Column = node.(*ColumnName) + n.Column = node.(*ColumnName) } - if ss.Pattern != nil { - node, ok := ss.Pattern.Accept(v) + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) if !ok { - return ss, false + return n, false } - ss.Pattern = node.(*PatternLikeExpr) + n.Pattern = node.(*PatternLikeExpr) } - if ss.Where != nil { - node, ok := ss.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return ss, false + return n, false } - ss.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - return v.Leave(ss) + return v.Leave(n) } // BeginStmt is a statement to start a new transaction. @@ -215,11 +231,13 @@ type BeginStmt struct { } // Accept implements Node Accept interface. -func (bs *BeginStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(bs); skipChildren { - return bs, ok +func (n *BeginStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(bs) + n = newNod.(*BeginStmt) + return v.Leave(n) } // CommitStmt is a statement to commit the current transaction. @@ -229,11 +247,13 @@ type CommitStmt struct { } // Accept implements Node Accept interface. -func (cs *CommitStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (n *CommitStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cs) + n = newNod.(*CommitStmt) + return v.Leave(n) } // RollbackStmt is a statement to roll back the current transaction. @@ -243,11 +263,13 @@ type RollbackStmt struct { } // Accept implements Node Accept interface. -func (rs *RollbackStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rs); skipChildren { - return rs, ok +func (n *RollbackStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(rs) + n = newNod.(*RollbackStmt) + return v.Leave(n) } // UseStmt is a statement to use the DBName database as the current database. @@ -259,11 +281,13 @@ type UseStmt struct { } // Accept implements Node Accept interface. -func (us *UseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(us); skipChildren { - return us, ok +func (n *UseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(us) + n = newNod.(*UseStmt) + return v.Leave(n) } // VariableAssignment is a variable assignment struct. @@ -276,16 +300,18 @@ type VariableAssignment struct { } // Accept implements Node interface. -func (va *VariableAssignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (n *VariableAssignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Value.Accept(v) + n = newNod.(*VariableAssignment) + node, ok := n.Value.Accept(v) if !ok { - return va, false + return n, false } - va.Value = node.(ExprNode) - return v.Leave(va) + n.Value = node.(ExprNode) + return v.Leave(n) } // SetStmt is the statement to set variables. @@ -296,18 +322,20 @@ type SetStmt struct { } // Accept implements Node Accept interface. -func (set *SetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (n *SetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range set.Variables { + n = newNod.(*SetStmt) + for i, val := range n.Variables { node, ok := val.Accept(v) if !ok { - return set, false + return n, false } - set.Variables[i] = node.(*VariableAssignment) + n.Variables[i] = node.(*VariableAssignment) } - return v.Leave(set) + return v.Leave(n) } // SetCharsetStmt is a statement to assign values to character and collation variables. @@ -320,11 +348,13 @@ type SetCharsetStmt struct { } // Accept implements Node Accept interface. -func (set *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + n = newNod.(*SetCharsetStmt) + return v.Leave(n) } // SetPwdStmt is a statement to assign a password to user account. @@ -337,11 +367,13 @@ type SetPwdStmt struct { } // Accept implements Node Accept interface. -func (set *SetPwdStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + n = newNod.(*SetPwdStmt) + return v.Leave(n) } // UserSpec is used for parsing create user statement. @@ -353,10 +385,20 @@ type UserSpec struct { // CreateUserStmt creates user account. // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html type CreateUserStmt struct { + stmtNode + IfNotExists bool Specs []*UserSpec +} - Text string +// Accept implements Node Accept interface. +func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*CreateUserStmt) + return v.Leave(n) } // DoStmt is the struct for DO statement. @@ -367,16 +409,100 @@ type DoStmt struct { } // Accept implements Node Accept interface. -func (do *DoStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(do); skipChildren { - return do, ok +func (n *DoStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range do.Exprs { + n = newNod.(*DoStmt) + for i, val := range n.Exprs { node, ok := val.Accept(v) if !ok { - return do, false + return n, false } - do.Exprs[i] = node.(ExprNode) + n.Exprs[i] = node.(ExprNode) } - return v.Leave(do) + return v.Leave(n) +} + +// PrivElem is the privilege type and optional column list. +type PrivElem struct { + node + Priv mysql.PrivilegeType + Cols []*ColumnName +} + +// Accept implements Node Accept interface. +func (n *PrivElem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*PrivElem) + for i, val := range n.Cols { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Cols[i] = node.(*ColumnName) + } + return v.Leave(n) +} + +// ObjectTypeType is the type for object type. +type ObjectTypeType int + +const ( + // ObjectTypeNone is for empty object type. + ObjectTypeNone ObjectTypeType = iota + // ObjectTypeTable means the following object is a table. + ObjectTypeTable +) + +// GrantLevelType is the type for grant level. +type GrantLevelType int + +const ( + // GrantLevelNone is the dummy const for default value. + GrantLevelNone GrantLevelType = iota + // GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server. + GrantLevelGlobal + // GrantLevelDB means the privileges apply to all objects in a given database. + GrantLevelDB + // GrantLevelTable means the privileges apply to all columns in a given table. + GrantLevelTable +) + +// GrantLevel is used for store the privilege scope. +type GrantLevel struct { + Level GrantLevelType + DBName string + TableName string +} + +// GrantStmt is the struct for GRANT statement. +type GrantStmt struct { + stmtNode + + Privs []*PrivElem + ObjectType ObjectTypeType + Level *GrantLevel + Users []*UserSpec +} + +// Accept implements Node Accept interface. +func (n *GrantStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*GrantStmt) + for i, val := range n.Privs { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Privs[i] = node.(*PrivElem) + } + return v.Leave(n) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y deleted file mode 100644 index 24cde087e6..0000000000 --- a/ast/parser/parser.y +++ /dev/null @@ -1,3917 +0,0 @@ -%{ -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -// Inital yacc source generated by ebnf2y[1] -// at 2013-10-04 23:10:47.861401015 +0200 CEST -// -// $ ebnf2y -o ql.y -oe ql.ebnf -start StatementList -pkg ql -p _ -// -// [1]: http://github.com/cznic/ebnf2y - -package parser - -import ( - "strings" - - "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/field" - "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/parser/opcode" - "github.com/pingcap/tidb/util/charset" - "github.com/pingcap/tidb/util/types" -) - -%} - -%union { - offset int // offset - line int - col int - item interface{} - list []interface{} -} - -%token - - /*yy:token "1.%d" */ floatLit "floating-point literal" - /*yy:token "%c" */ identifier "identifier" - /*yy:token "%d" */ intLit "integer literal" - /*yy:token "\"%c\"" */ stringLit "string literal" - /*yy:token "%x" */ hexLit "hexadecimal literal" - /*yy:token "%b" */ bitLit "bit literal" - - - abs "ABS" - add "ADD" - after "AFTER" - all "ALL" - alter "ALTER" - and "AND" - andand "&&" - andnot "&^" - any "ANY" - as "AS" - asc "ASC" - at "AT" - autoIncrement "AUTO_INCREMENT" - avg "AVG" - avgRowLength "AVG_ROW_LENGTH" - begin "BEGIN" - between "BETWEEN" - both "BOTH" - by "BY" - byteType "BYTE" - caseKwd "CASE" - cast "CAST" - character "CHARACTER" - charsetKwd "CHARSET" - check "CHECK" - checksum "CHECKSUM" - coalesce "COALESCE" - collate "COLLATE" - collation "COLLATION" - column "COLUMN" - columns "COLUMNS" - comment "COMMENT" - commit "COMMIT" - compression "COMPRESSION" - concat "CONCAT" - concatWs "CONCAT_WS" - connection "CONNECTION" - constraint "CONSTRAINT" - convert "CONVERT" - count "COUNT" - create "CREATE" - cross "CROSS" - curDate "CURDATE" - currentDate "CURRENT_DATE" - currentUser "CURRENT_USER" - database "DATABASE" - databases "DATABASES" - day "DAY" - dayofmonth "DAYOFMONTH" - dayofweek "DAYOFWEEK" - dayofyear "DAYOFYEAR" - deallocate "DEALLOCATE" - defaultKwd "DEFAULT" - delayed "DELAYED" - deleteKwd "DELETE" - desc "DESC" - describe "DESCRIBE" - distinct "DISTINCT" - div "DIV" - do "DO" - drop "DROP" - dual "DUAL" - duplicate "DUPLICATE" - elseKwd "ELSE" - end "END" - engine "ENGINE" - engines "ENGINES" - enum "ENUM" - eq "=" - escape "ESCAPE" - execute "EXECUTE" - exists "EXISTS" - explain "EXPLAIN" - extract "EXTRACT" - falseKwd "false" - first "FIRST" - foreign "FOREIGN" - forKwd "FOR" - foundRows "FOUND_ROWS" - from "FROM" - full "FULL" - fulltext "FULLTEXT" - ge ">=" - global "GLOBAL" - group "GROUP" - groupConcat "GROUP_CONCAT" - having "HAVING" - highPriority "HIGH_PRIORITY" - hour "HOUR" - identified "IDENTIFIED" - ignore "IGNORE" - ifKwd "IF" - ifNull "IFNULL" - in "IN" - index "INDEX" - inner "INNER" - insert "INSERT" - into "INTO" - is "IS" - join "JOIN" - key "KEY" - keyBlockSize "KEY_BLOCK_SIZE" - le "<=" - leading "LEADING" - left "LEFT" - length "LENGTH" - like "LIKE" - limit "LIMIT" - local "LOCAL" - locate "LOCATE" - lock "LOCK" - lower "LOWER" - lowPriority "LOW_PRIORITY" - lsh "<<" - max "MAX" - maxRows "MAX_ROWS" - microsecond "MICROSECOND" - min "MIN" - minute "MINUTE" - minRows "MIN_ROWS" - mod "MOD" - mode "MODE" - month "MONTH" - names "NAMES" - national "NATIONAL" - neq "!=" - neqSynonym "<>" - not "NOT" - null "NULL" - nulleq "<=>" - nullIf "NULLIF" - offset "OFFSET" - on "ON" - or "OR" - order "ORDER" - oror "||" - outer "OUTER" - password "PASSWORD" - placeholder "PLACEHOLDER" - prepare "PREPARE" - primary "PRIMARY" - quarter "QUARTER" - quick "QUICK" - rand "RAND" - references "REFERENCES" - regexp "REGEXP" - repeat "REPEAT" - right "RIGHT" - rlike "RLIKE" - rollback "ROLLBACK" - row "ROW" - rsh ">>" - schema "SCHEMA" - schemas "SCHEMAS" - second "SECOND" - selectKwd "SELECT" - session "SESSION" - set "SET" - share "SHARE" - show "SHOW" - signed "SIGNED" - some "SOME" - start "START" - stringType "string" - substring "SUBSTRING" - substringIndex "SUBSTRING_INDEX" - sum "SUM" - sysVar "SYS_VAR" - sysDate "SYSDATE" - tableKwd "TABLE" - tables "TABLES" - then "THEN" - trailing "TRAILING" - transaction "TRANSACTION" - trim "TRIM" - trueKwd "true" - truncate "TRUNCATE" - unknown "UNKNOWN" - union "UNION" - unique "UNIQUE" - unsigned "UNSIGNED" - update "UPDATE" - upper "UPPER" - use "USE" - user "USER" - using "USING" - userVar "USER_VAR" - value "VALUE" - values "VALUES" - variables "VARIABLES" - warnings "WARNINGS" - week "WEEK" - weekday "WEEKDAY" - weekofyear "WEEKOFYEAR" - when "WHEN" - where "WHERE" - xor "XOR" - yearweek "YEARWEEK" - zerofill "ZEROFILL" - - calcFoundRows "SQL_CALC_FOUND_ROWS" - - currentTs "CURRENT_TIMESTAMP" - localTime "LOCALTIME" - localTs "LOCALTIMESTAMP" - now "NOW" - - tinyIntType "TINYINT" - smallIntType "SMALLINT" - mediumIntType "MEDIUMINT" - intType "INT" - integerType "INTEGER" - bigIntType "BIGINT" - bitType "BIT" - - decimalType "DECIMAL" - numericType "NUMERIC" - floatType "float" - doubleType "DOUBLE" - precisionType "PRECISION" - realType "REAL" - - dateType "DATE" - timeType "TIME" - datetimeType "DATETIME" - timestampType "TIMESTAMP" - yearType "YEAR" - - charType "CHAR" - varcharType "VARCHAR" - binaryType "BINARY" - varbinaryType "VARBINARY" - tinyblobType "TINYBLOB" - blobType "BLOB" - mediumblobType "MEDIUMBLOB" - longblobType "LONGBLOB" - tinytextType "TINYTEXT" - textType "TEXT" - mediumtextType "MEDIUMTEXT" - longtextType "LONGTEXT" - - int16Type "int16" - int24Type "int24" - int32Type "int32" - int64Type "int64" - int8Type "int8" - uintType "uint" - uint16Type "uint16" - uint32Type "uint32" - uint64Type "uint64" - uint8Type "uint8", - float32Type "float32" - float64Type "float64" - boolType "BOOL" - booleanType "BOOLEAN" - - parseExpression "parse expression prefix" - - secondMicrosecond "SECOND_MICROSECOND" - minuteMicrosecond "MINUTE_MICROSECOND" - minuteSecond "MINUTE_SECOND" - hourMicrosecond "HOUR_MICROSECOND" - hourSecond "HOUR_SECOND" - hourMinute "HOUR_MINUTE" - dayMicrosecond "DAY_MICROSECOND" - daySecond "DAY_SECOND" - dayMinute "DAY_MINUTE" - dayHour "DAY_HOUR" - yearMonth "YEAR_MONTH" - -%type - AlterTableStmt "Alter table statement" - AlterTableSpec "Alter table specification" - AlterTableSpecList "Alter table specification list" - AnyOrAll "Any or All for subquery" - Assignment "assignment" - AssignmentList "assignment list" - AssignmentListOpt "assignment list opt" - AuthOption "User auth option" - AuthString "Password string value" - BeginTransactionStmt "BEGIN TRANSACTION statement" - CastType "Cast function target type" - ColumnDef "table column definition" - ColumnName "column name" - ColumnNameList "column name list" - ColumnNameListOpt "column name list opt" - ColumnKeywordOpt "Column keyword or empty" - ColumnSetValue "insert statement set value by column name" - ColumnSetValueList "insert statement set value by column name list" - CommaOpt "optional comma" - CommitStmt "COMMIT statement" - CompareOp "Compare opcode" - ColumnOption "column definition option" - ColumnOptionList "column definition option list" - ColumnOptionListOpt "optional column definition option list" - Constraint "table constraint" - ConstraintElem "table constraint element" - ConstraintKeywordOpt "Constraint Keyword or empty" - CreateDatabaseStmt "Create Database Statement" - CreateIndexStmt "CREATE INDEX statement" - CreateIndexStmtUnique "CREATE INDEX optional UNIQUE clause" - DatabaseOption "CREATE Database specification" - DatabaseOptionList "CREATE Database specification list" - DatabaseOptionListOpt "CREATE Database specification list opt" - CreateTableStmt "CREATE TABLE statement" - CreateUserStmt "CREATE User statement" - CrossOpt "Cross join option" - DatabaseSym "DATABASE or SCHEMA" - DBName "Database Name" - DeallocateSym "Deallocate or drop" - DeallocateStmt "Deallocate prepared statement" - Default "DEFAULT clause" - DefaultOpt "optional DEFAULT clause" - DefaultKwdOpt "optional DEFAULT keyword" - DefaultValueExpr "DefaultValueExpr(Now or Signed Literal)" - DeleteFromStmt "DELETE FROM statement" - DistinctOpt "Distinct option" - DoStmt "Do statement" - DropDatabaseStmt "DROP DATABASE statement" - DropIndexStmt "DROP INDEX statement" - DropTableStmt "DROP TABLE statement" - EmptyStmt "empty statement" - EqOpt "= or empty" - EscapedTableRef "escaped table reference" - ExecuteStmt "Execute statement" - ExplainSym "EXPLAIN or DESCRIBE or DESC" - ExplainStmt "EXPLAIN statement" - Expression "expression" - ExpressionList "expression list" - ExpressionListOpt "expression list opt" - ExpressionListList "expression list list" - Factor "expression factor" - PredicateExpr "Predicate expression factor" - Field "field expression" - FieldAsName "Field alias name" - FieldAsNameOpt "Field alias name opt" - FieldList "field expression list" - FromClause "From clause" - Function "function expr" - FunctionCallAgg "Function call on aggregate data" - FunctionCallConflict "Function call with reserved keyword as function name" - FunctionCallKeyword "Function call with keyword as function name" - FunctionCallNonKeyword "Function call with nonkeyword as function name" - FunctionNameConflict "Built-in function call names which are conflict with keywords" - FuncDatetimePrec "Function datetime precision" - GlobalScope "The scope of variable" - GroupByClause "GROUP BY clause" - GroupByList "GROUP BY list" - HashString "Hashed string" - HavingClause "HAVING clause" - IfExists "If Exists" - IfNotExists "If Not Exists" - IgnoreOptional "IGNORE or empty" - IndexColName "Index column name" - IndexColNameList "List of index column name" - IndexName "index name" - IndexType "index type" - InsertIntoStmt "INSERT INTO statement" - InsertRest "Rest part of INSERT INTO statement" - IntoOpt "INTO or EmptyString" - JoinTable "join table" - JoinType "join type" - KeyOrIndex "{KEY|INDEX}" - LikeEscapeOpt "like escape option" - LimitClause "LIMIT clause" - Literal "literal value" - logAnd "logical and operator" - logOr "logical or operator" - LowPriorityOptional "LOW_PRIORITY or empty" - name "name" - NationalOpt "National option" - NotOpt "optional NOT" - NowSym "CURRENT_TIMESTAMP/LOCALTIME/LOCALTIMESTAMP/NOW" - NumLiteral "Num/Int/Float/Decimal Literal" - OnDuplicateKeyUpdate "ON DUPLICATE KEY UPDATE value list" - Operand "operand" - OptFull "Full or empty" - OptInteger "Optional Integer keyword" - Order "ORDER BY clause optional collation specification" - OrderBy "ORDER BY clause" - OrderByItem "ORDER BY list item" - OrderByOptional "Optional ORDER BY clause optional" - OrderByList "ORDER BY list" - OuterOpt "optional OUTER clause" - QuickOptional "QUICK or empty" - PasswordOpt "Password option" - ColumnPosition "Column position [First|After ColumnName]" - PreparedStmt "PreparedStmt" - PrepareSQL "Prepare statement sql string" - PrimaryExpression "primary expression" - PrimaryFactor "primary expression factor" - Priority "insert statement priority" - ReferDef "Reference definition" - RegexpSym "REGEXP or RLIKE" - RollbackStmt "ROLLBACK statement" - SelectBasic "Basic SELECT statement without parentheses and union" - SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," - SelectStmt "SELECT statement" - SelectStmtCalcFoundRows "SELECT statement optional SQL_CALC_FOUND_ROWS" - SelectStmtDistinct "SELECT statement optional DISTINCT clause" - SelectStmtFieldList "SELECT statement field list" - SelectStmtLimit "SELECT statement optional LIMIT clause" - SelectStmtOpts "Select statement options" - SelectStmtGroup "SELECT statement optional GROUP BY clause" - SetStmt "Set variable statement" - ShowStmt "Show engines/databases/tables/columns/warnings statement" - ShowDatabaseNameOpt "Show tables/columns statement database name option" - ShowTableAliasOpt "Show table alias option" - ShowLikeOrWhereOpt "Show like or where condition option" - SignedLiteral "Literal or NumLiteral with sign" - Statement "statement" - StatementList "statement list" - StringName "string literal or identifier" - StringList "string list" - ExplainableStmt "explainable statement" - SubSelect "Sub Select" - Symbol "Constraint Symbol" - SystemVariable "System defined variable name" - TableAsName "table alias name" - TableAsNameOpt "table alias name optional" - TableElement "table definition element" - TableElementList "table definition element list" - TableFactor "table factor" - TableName "Table name" - TableNameList "Table name list" - TableOption "create table option" - TableOptionList "create table option list" - TableOptionListOpt "create table option list opt" - TableRef "table reference" - TableRefs "table references" - TimeUnit "Time unit" - TrimDirection "Trim string direction" - TruncateTableStmt "TRANSACTION TABLE statement" - UnionOpt "Union Option(empty/ALL/DISTINCT)" - UnionClause "Union select" - UnionClauseList "Union select clause list" - UnionClauseP "Union (select)" - UnionClausePList "Union (select) clause list" - UpdateStmt "UPDATE statement" - Username "Username" - UserSpec "Username and auth option" - UserSpecList "Username and auth option list" - UserVariable "User defined variable name" - UserVariableList "User defined variable name list" - UseStmt "USE statement" - ValueSym "Value or Values" - VariableAssignment "set variable value" - VariableAssignmentList "set variable value list" - Variable "User or system variable" - WhereClause "WHERE clause" - WhereClauseOptional "Optinal WHERE clause" - - Identifier "identifier or unreserved keyword" - UnReservedKeyword "MySQL unreserved keywords" - NotKeywordToken "Tokens not mysql keyword but treated specially" - - WhenClause "When clause" - WhenClauseList "When clause list" - ElseOpt "Optional else clause" - ExpressionOpt "Optional expression" - - Type "Types" - - NumericType "Numeric types" - IntegerType "Integer Types types" - FixedPointType "Exact value types" - FloatingPointType "Approximate value types" - BitValueType "bit value types" - - StringType "String types" - BlobType "Blob types" - TextType "Text types" - - DateAndTimeType "Date and Time types" - - OptFieldLen "Field length or empty" - FieldLen "Field length" - FieldOpts "Field type definition option list" - FieldOpt "Field type definition option" - FloatOpt "Floating-point type option" - Precision "Floating-point precision option" - OptBinary "Optional BINARY" - CharsetKw "charset or charater set" - OptCharset "Optional Character setting" - OptCollate "Optional Collate setting" - NUM "numbers" - LengthNum "Field length num(uint64)" - -%token tableRefPriority - -%precedence lowerThanCalcFoundRows -%precedence calcFoundRows - -%precedence lowerThanSetKeyword -%precedence set - -%precedence lowerThanInsertValues -%precedence insertValues - -%left join inner cross left right full -/* A dummy token to force the priority of TableRef production in a join. */ -%left tableRefPriority -%precedence on -%left oror or -%left xor -%left andand and -%left between -%precedence lowerThanEq -%left eq ge le neq neqSynonym '>' '<' is like in -%left '|' -%left '&' -%left rsh lsh -%left '-' '+' -%left '*' '/' '%' div mod -%left '^' -%left '~' neg -%right not -%right collate - -%precedence lowerThanLeftParen -%precedence '(' -%precedence lowerThanQuick -%precedence quick -%precedence lowerThanEscape -%precedence escape -%precedence lowerThanComma -%precedence ',' - -%start Start - -%% - -Start: - StatementList -| parseExpression Expression - { - yylex.(*lexer).expr = $2.(ast.ExprNode) - } - -/**************************************AlterTableStmt*************************************** - * See: https://dev.mysql.com/doc/refman/5.7/en/alter-table.html - *******************************************************************************************/ -AlterTableStmt: - "ALTER" IgnoreOptional "TABLE" TableName AlterTableSpecList - { - $$ = &ast.AlterTableStmt{ - Table: $4.(*ast.TableName), - Specs: $5.([]*ast.AlterTableSpec), - } - } - -AlterTableSpec: - TableOptionListOpt - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableOption, - Options:$1.([]*ast.TableOption), - } - } -| "ADD" ColumnKeywordOpt ColumnDef ColumnPosition - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableAddColumn, - Column: $3.(*ast.ColumnDef), - Position: $4.(*ast.ColumnPosition), - } - } -| "ADD" Constraint - { - constraint := $2.(*ast.Constraint) - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableAddConstraint, - Constraint: constraint, - } - } -| "DROP" ColumnKeywordOpt ColumnName - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableDropColumn, - ColumnName: $3.(*ast.ColumnName), - } - } -| "DROP" "PRIMARY" "KEY" - { - $$ = &ast.AlterTableSpec{Tp: ast.AlterTableDropPrimaryKey} - } -| "DROP" KeyOrIndex IndexName - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableDropIndex, - Name: $3.(string), - } - } -| "DROP" "FOREIGN" "KEY" Symbol - { - $$ = &ast.AlterTableSpec{ - Tp: ast.AlterTableDropForeignKey, - Name: $4.(string), - } - } - -KeyOrIndex: - "KEY"|"INDEX" - -ColumnKeywordOpt: - {} -| "COLUMN" - -ColumnPosition: - { - $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionNone} - } -| "FIRST" - { - $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionFirst} - } -| "AFTER" ColumnName - { - $$ = &ast.ColumnPosition{ - Tp: ast.ColumnPositionAfter, - RelativeColumn: $2.(*ast.ColumnName), - } - } - -AlterTableSpecList: - AlterTableSpec - { - $$ = []*ast.AlterTableSpec{$1.(*ast.AlterTableSpec)} - } -| AlterTableSpecList ',' AlterTableSpec - { - $$ = append($1.([]*ast.AlterTableSpec), $3.(*ast.AlterTableSpec)) - } - -ConstraintKeywordOpt: - { - $$ = nil - } -| "CONSTRAINT" - { - $$ = nil - } -| "CONSTRAINT" Symbol - { - $$ = $2.(string) - } - -Symbol: - Identifier - -/*******************************************************************************************/ -Assignment: - ColumnName eq Expression - { - $$ = &ast.Assignment{Column: $1.(*ast.ColumnName), Expr:$3.(ast.ExprNode)} - } - -AssignmentList: - Assignment - { - $$ = []*ast.Assignment{$1.(*ast.Assignment)} - } -| AssignmentList ',' Assignment - { - $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) - } - -AssignmentListOpt: - /* EMPTY */ - { - $$ = []*ast.Assignment{} - } -| AssignmentList - -BeginTransactionStmt: - "BEGIN" - { - $$ = &ast.BeginStmt{} - } -| "START" "TRANSACTION" - { - $$ = &ast.BeginStmt{} - } - -ColumnDef: - ColumnName Type ColumnOptionListOpt - { - $$ = &ast.ColumnDef{Name: $1.(*ast.ColumnName), Tp: $2.(*types.FieldType), Options: $3.([]*ast.ColumnOption)} - } - -ColumnName: - Identifier - { - $$ = &ast.ColumnName{Name: model.NewCIStr($1.(string))} - } -| Identifier '.' Identifier - { - $$ = &ast.ColumnName{Table: model.NewCIStr($1.(string)), Name: model.NewCIStr($3.(string))} - } -| Identifier '.' Identifier '.' Identifier - { - $$ = &ast.ColumnName{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string)), Name: model.NewCIStr($5.(string))} - } - -ColumnNameList: - ColumnName - { - $$ = []*ast.ColumnName{$1.(*ast.ColumnName)} - } -| ColumnNameList ',' ColumnName - { - $$ = append($1.([]*ast.ColumnName), $3.(*ast.ColumnName)) - } - -ColumnNameListOpt: - /* EMPTY */ - { - $$ = []*ast.ColumnName{} - } -| '(' ')' - { - $$ = []*ast.ColumnName{} - } -| '(' ColumnNameList ')' - { - $$ = $2.([]*ast.ColumnName) - } - -CommitStmt: - "COMMIT" - { - $$ = &ast.CommitStmt{} - } - -ColumnOption: - "NOT" "NULL" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNotNull} - } -| "NULL" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNull} - } -| "AUTO_INCREMENT" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionAutoIncrement} - } -| "PRIMARY" "KEY" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionPrimaryKey} - } -| "UNIQUE" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniq} - } -| "UNIQUE" "KEY" - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniqKey} - } -| "DEFAULT" DefaultValueExpr - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionDefaultValue, Expr: $2.(ast.ExprNode)} - } -| "ON" "UPDATE" NowSym - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionOnUpdate, Expr: $3.(ast.ExprNode)} - } -| "COMMENT" stringLit - { - $$ = &ast.ColumnOption{Tp: ast.ColumnOptionComment} - } -| "CHECK" '(' Expression ')' - { - // See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html - // The CHECK clause is parsed but ignored by all storage engines. - $$ = nil - } - -ColumnOptionList: - ColumnOption - { - $$ = []*ast.ColumnOption{$1.(*ast.ColumnOption)} - } -| ColumnOptionList ColumnOption - { - $$ = append($1.([]*ast.ColumnOption), $2.(*ast.ColumnOption)) - } - -ColumnOptionListOpt: - { - $$ = []*ast.ColumnOption{} - } -| ColumnOptionList - { - $$ = $1.([]*ast.ColumnOption) - } - -ConstraintElem: - "PRIMARY" "KEY" '(' IndexColNameList ')' - { - $$ = &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: $4.([]*ast.IndexColName)} - } -| "FULLTEXT" "KEY" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintFulltext, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string), - } - } -| "INDEX" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintIndex, - Keys: $4.([]*ast.IndexColName), - Name: $2.(string), - } - } -| "KEY" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintKey, - Keys: $4.([]*ast.IndexColName), - Name: $2.(string)} - } -| "UNIQUE" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintUniq, - Keys: $4.([]*ast.IndexColName), - Name: $2.(string)} - } -| "UNIQUE" "INDEX" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintUniqIndex, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string)} - } -| "UNIQUE" "KEY" IndexName '(' IndexColNameList ')' - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintUniqKey, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string)} - } -| "FOREIGN" "KEY" IndexName '(' IndexColNameList ')' ReferDef - { - $$ = &ast.Constraint{ - Tp: ast.ConstraintForeignKey, - Keys: $5.([]*ast.IndexColName), - Name: $3.(string), - Refer: $7.(*ast.ReferenceDef), - } - } - -ReferDef: - "REFERENCES" TableName '(' IndexColNameList ')' - { - $$ = &ast.ReferenceDef{Table: $2.(*ast.TableName), IndexColNames: $4.([]*ast.IndexColName)} - } - -/* - * The DEFAULT clause specifies a default value for a column. - * With one exception, the default value must be a constant; - * it cannot be a function or an expression. This means, for example, - * that you cannot set the default for a date column to be the value of - * a function such as NOW() or CURRENT_DATE. The exception is that you - * can specify CURRENT_TIMESTAMP as the default for a TIMESTAMP or DATETIME column. - * - * See: http://dev.mysql.com/doc/refman/5.7/en/create-table.html - * https://github.com/mysql/mysql-server/blob/5.7/sql/sql_yacc.yy#L6832 - */ -DefaultValueExpr: - NowSym -| SignedLiteral - -// TODO: Process other three keywords -NowSym: - "CURRENT_TIMESTAMP" - { - $$ = &ast.IdentifierExpr{Name: model.NewCIStr("CURRENT_TIMESTAMP")} - } -| "LOCALTIME" -| "LOCALTIMESTAMP" -| "NOW" - -SignedLiteral: - Literal - { - $$ = ast.ValueExpr{Val: $1} - } -| '+' NumLiteral - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: &ast.ValueExpr{Val: $2}} - } -| '-' NumLiteral - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: &ast.ValueExpr{Val: $2}} - } - -// TODO: support decimal literal -NumLiteral: - intLit -| floatLit - - -CreateIndexStmt: - "CREATE" CreateIndexStmtUnique "INDEX" Identifier "ON" TableName '(' IndexColNameList ')' - { - $$ = &ast.CreateIndexStmt{ - Unique: $2.(bool), - IndexName: $4.(string), - Table: $6.(*ast.TableName), - IndexColNames: $8.([]*ast.IndexColName), - } - if yylex.(*lexer).root { - break - } - } - -CreateIndexStmtUnique: - { - $$ = false - } -| "UNIQUE" - { - $$ = true - } - -IndexColName: - ColumnName OptFieldLen Order - { - //Order is parsed but just ignored as MySQL did - $$ = &ast.IndexColName{Column: $1.(*ast.ColumnName), Length: $2.(int)} - } - -IndexColNameList: - { - $$ = []*ast.IndexColName{} - } -| IndexColName - { - $$ = []*ast.IndexColName{$1.(*ast.IndexColName)} - } -| IndexColNameList ',' IndexColName - { - $$ = append($1.([]*ast.IndexColName), $3.(*ast.IndexColName)) - } - - - -/******************************************************************* - * - * Create Database Statement - * CREATE {DATABASE | SCHEMA} [IF NOT EXISTS] db_name - * [create_specification] ... - * - * create_specification: - * [DEFAULT] CHARACTER SET [=] charset_name - * | [DEFAULT] COLLATE [=] collation_name - *******************************************************************/ -CreateDatabaseStmt: - "CREATE" DatabaseSym IfNotExists DBName DatabaseOptionListOpt - { - $$ = &ast.CreateDatabaseStmt{ - IfNotExists: $3.(bool), - Name: $4.(string), - Options: $5.([]*ast.DatabaseOption), - } - - if yylex.(*lexer).root { - break - } - } - -DBName: - Identifier - -DatabaseOption: - DefaultKwdOpt CharsetKw EqOpt StringName - { - $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCharset, Value: $4.(string)} - } -| DefaultKwdOpt "COLLATE" EqOpt StringName - { - $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: $4.(string)} - } - -DatabaseOptionListOpt: - { - $$ = []*ast.DatabaseOption{} - } -| DatabaseOptionList - -DatabaseOptionList: - DatabaseOption - { - $$ = []*ast.DatabaseOption{$1.(*ast.DatabaseOption)} - } -| DatabaseOptionList DatabaseOption - { - $$ = append($1.([]*ast.DatabaseOption), $2.(*ast.DatabaseOption)) - } - -/******************************************************************* - * - * Create Table Statement - * - * Example: - * CREATE TABLE Persons - * ( - * P_Id int NOT NULL, - * LastName varchar(255) NOT NULL, - * FirstName varchar(255), - * Address varchar(255), - * City varchar(255), - * PRIMARY KEY (P_Id) - * ) - *******************************************************************/ -CreateTableStmt: - "CREATE" "TABLE" IfNotExists TableName '(' TableElementList ')' TableOptionListOpt - { - tes := $6.([]interface {}) - var columnDefs []*ast.ColumnDef - var constraints []*ast.Constraint - for _, te := range tes { - switch te := te.(type) { - case *ast.ColumnDef: - columnDefs = append(columnDefs, te) - case *ast.Constraint: - constraints = append(constraints, te) - } - } - if len(columnDefs) == 0 { - yylex.(*lexer).err("Column Definition List can't be empty.") - return 1 - } - $$ = &ast.CreateTableStmt{ - Table: $4.(*ast.TableName), - IfNotExists: $3.(bool), - Cols: columnDefs, - Constraints: constraints, - Options: $8.([]*ast.TableOption), - } - } - -Default: - "DEFAULT" Expression - { - $$ = $2 - } - -DefaultOpt: - { - $$ = nil - } -| Default - -DefaultKwdOpt: - {} -| "DEFAULT" - -/****************************************************************** - * Do statement - * See: https://dev.mysql.com/doc/refman/5.7/en/do.html - ******************************************************************/ -DoStmt: - "DO" ExpressionList - { - $$ = &ast.DoStmt { - Exprs: $2.([]ast.ExprNode), - } - } - -/******************************************************************* - * - * Delete Statement - * - *******************************************************************/ -DeleteFromStmt: - "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableName WhereClauseOptional OrderByOptional LimitClause - { - // Single Table - x := &ast.DeleteStmt{ - TableRefs: &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil}, - LowPriority: $2.(bool), - Quick: $3.(bool), - Ignore: $4.(bool), - Order: $8.([]*ast.OrderByItem), - } - if $7 != nil { - x.Where = $7.(ast.ExprNode) - } - - if $9 != nil { - x.Limit = $9.(*ast.Limit) - } - - $$ = x - if yylex.(*lexer).root { - break - } - } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional TableNameList "FROM" TableRefs WhereClauseOptional - { - // Multiple Table - x := &ast.DeleteStmt{ - LowPriority: $2.(bool), - Quick: $3.(bool), - Ignore: $4.(bool), - MultiTable: true, - BeforeFrom: true, - Tables: $5.([]*ast.TableName), - TableRefs: $7.(*ast.Join), - } - if $8 != nil { - x.Where = $8.(ast.ExprNode) - } - $$ = x - if yylex.(*lexer).root { - break - } - } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableNameList "USING" TableRefs WhereClauseOptional - { - // Multiple Table - x := &ast.DeleteStmt{ - LowPriority: $2.(bool), - Quick: $3.(bool), - Ignore: $4.(bool), - MultiTable: true, - Tables: $6.([]*ast.TableName), - TableRefs: $8.(*ast.Join), - } - if $9 != nil { - x.Where = $9.(ast.ExprNode) - } - $$ = x - if yylex.(*lexer).root { - break - } - } - -DatabaseSym: - "DATABASE" | "SCHEMA" - -DropDatabaseStmt: - "DROP" DatabaseSym IfExists DBName - { - $$ = &ast.DropDatabaseStmt{IfExists: $3.(bool), Name: $4.(string)} - if yylex.(*lexer).root { - break - } - } - -DropIndexStmt: - "DROP" "INDEX" IfExists Identifier "ON" TableName - { - $$ = &ast.DropIndexStmt{IfExists: $3.(bool), IndexName: $4.(string), Table: $6.(*ast.TableName)} - } - -DropTableStmt: - "DROP" "TABLE" TableNameList - { - $$ = &ast.DropTableStmt{Tables: $3.([]*ast.TableName)} - if yylex.(*lexer).root { - break - } - } -| "DROP" "TABLE" "IF" "EXISTS" TableNameList - { - $$ = &ast.DropTableStmt{IfExists: true, Tables: $5.([]*ast.TableName)} - if yylex.(*lexer).root { - break - } - } - -EqOpt: - { - } -| eq - { - } - -EmptyStmt: - /* EMPTY */ - { - $$ = nil - } - -ExplainSym: - "EXPLAIN" -| "DESCRIBE" -| "DESC" - -ExplainStmt: - ExplainSym TableName - { - $$ = &ast.ExplainStmt{ - Stmt: &ast.ShowStmt{ - Tp: ast.ShowTables, - Table: $2.(*ast.TableName), - }, - } - } -| ExplainSym TableName ColumnName - { - $$ = &ast.ExplainStmt{ - Stmt: &ast.ShowStmt{ - Tp: ast.ShowColumns, - Table: $2.(*ast.TableName), - Column: $3.(*ast.ColumnName), - }, - } - } -| ExplainSym ExplainableStmt - { - $$ = &ast.ExplainStmt{Stmt: $2.(ast.StmtNode)} - } - -LengthNum: - NUM - { - switch v := $1.(type) { - case int64: - $$ = uint64(v) - case uint64: - $$ = uint64(v) - } - } - -NUM: - intLit - -Expression: - Expression logOr Expression %prec oror - { - $$ = &ast.BinaryOperationExpr{Op: opcode.OrOr, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| Expression "XOR" Expression %prec xor - { - $$ = &ast.BinaryOperationExpr{Op: opcode.LogicXor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| Expression logAnd Expression %prec andand - { - $$ = &ast.BinaryOperationExpr{Op: opcode.AndAnd, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| "NOT" Expression %prec not - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} - } -| Factor "IS" NotOpt trueKwd %prec is - { - $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(1)} - } -| Factor "IS" NotOpt falseKwd %prec is - { - $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(0)} - } -| Factor "IS" NotOpt "UNKNOWN" %prec is - { - /* https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#operator_is */ - $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} - } -| Factor - - -logOr: - "||" - { - } -| "OR" - { - } - -logAnd: - "&&" - { - } -| "AND" - { - } - -name: - Identifier - -ExpressionList: - Expression - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| ExpressionList ',' Expression - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) - } - -ExpressionListOpt: - { - $$ = []ast.ExprNode{} - } -| ExpressionList - -Factor: - Factor "IS" NotOpt "NULL" %prec is - { - $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} - } -| Factor CompareOp PredicateExpr %prec eq - { - $$ = &ast.BinaryOperationExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| Factor CompareOp AnyOrAll SubSelect %prec eq - { - $$ = &ast.CompareSubqueryExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $4.(*ast.SubqueryExpr), All: $3.(bool)} - } -| PredicateExpr - -CompareOp: - ">=" - { - $$ = opcode.GE - } -| '>' - { - $$ = opcode.GT - } -| "<=" - { - $$ = opcode.LE - } -| '<' - { - $$ = opcode.LT - } -| "!=" - { - $$ = opcode.NE - } -| "<>" - { - $$ = opcode.NE - } -| "=" - { - $$ = opcode.EQ - } -| "<=>" - { - $$ = opcode.NullEQ - } - -AnyOrAll: - "ANY" - { - $$ = false - } -| "SOME" - { - $$ = false - } -| "ALL" - { - $$ = true - } - -PredicateExpr: - PrimaryFactor NotOpt "IN" '(' ExpressionList ')' - { - $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), List: $5.([]ast.ExprNode)} - } -| PrimaryFactor NotOpt "IN" SubSelect - { - $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), Sel: $4.(*ast.SubqueryExpr)} - } -| PrimaryFactor NotOpt "BETWEEN" PrimaryFactor "AND" PredicateExpr - { - $$ = &ast.BetweenExpr{ - Expr: $1.(ast.ExprNode), - Left: $4.(ast.ExprNode), - Right: $6.(ast.ExprNode), - Not: $2.(bool), - } - } -| PrimaryFactor NotOpt "LIKE" PrimaryExpression LikeEscapeOpt - { - escape := $5.(string) - if len(escape) > 1 { - yylex.(*lexer).errf("Incorrect arguments %s to ESCAPE", escape) - return 1 - } else if len(escape) == 0 { - escape = "\\" - } - $$ = &ast.PatternLikeExpr{ - Expr: $1.(ast.ExprNode), - Pattern: $4.(ast.ExprNode), - Not: $2.(bool), - Escape: escape[0], - } - } -| PrimaryFactor NotOpt RegexpSym PrimaryExpression - { - $$ = &ast.PatternRegexpExpr{Expr: $1.(ast.ExprNode), Pattern: $4.(ast.ExprNode), Not: $2.(bool)} - } -| PrimaryFactor - -RegexpSym: - "REGEXP" -| "RLIKE" - -LikeEscapeOpt: - %prec lowerThanEscape - { - $$ = "\\" - } -| "ESCAPE" stringLit - { - $$ = $2 - } - -NotOpt: - { - $$ = false - } -| "NOT" - { - $$ = true - } - -Field: - '*' - { - $$ = &ast.SelectField{WildCard: &ast.WildCardField{}} - } -| Identifier '.' '*' - { - tn := &ast.TableName{Name:model.NewCIStr($1.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} - } -| Identifier '.' Identifier '.' '*' - { - tn := &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} - } -| Expression FieldAsNameOpt - { - $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: $2.(model.CIStr)} - } - -FieldAsNameOpt: - /* EMPTY */ - { - $$ = model.CIStr{} - } -| FieldAsName - { - $$ = $1 - } - -FieldAsName: - Identifier - { - $$ = $1 - } -| "AS" Identifier - { - $$ = $2 - } -| stringLit - { - $$ = $1 - } -| "AS" stringLit - { - $$ = $2 - } - -FieldList: - Field - { - $$ = []*ast.SelectField{$1.(*ast.SelectField)} - } -| FieldList ',' Field - { - $$ = append($1.([]*ast.SelectField), $3.(*ast.SelectField)) - } - -GroupByClause: - "GROUP" "BY" GroupByList - { - $$ = $3.([]ast.ExprNode) - } - -GroupByList: - Expression - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| GroupByList ',' Expression - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) - } - -HavingClause: - { - $$ = nil - } -| "HAVING" Expression - { - $$ = $2.(ast.ExprNode) - } - -IfExists: - { - $$ = false - } -| "IF" "EXISTS" - { - $$ = true - } - -IfNotExists: - { - $$ = false - } -| "IF" "NOT" "EXISTS" - { - $$ = true - } - - -IgnoreOptional: - { - $$ = false - } -| "IGNORE" - { - $$ = true - } - -IndexName: - { - $$ = "" - } -| Identifier - { - //"index name" - $$ = $1.(string) - } - -IndexType: - Identifier - { - // TODO: "index type" - $$ = $1.(string) - } - -/**********************************Identifier********************************************/ -Identifier: - identifier | UnReservedKeyword | NotKeywordToken - -UnReservedKeyword: - "AUTO_INCREMENT" | "AFTER" | "AVG" | "BEGIN" | "BIT" | "BOOL" | "BOOLEAN" | "CHARSET" | "COLUMNS" | "COMMIT" -| "DATE" | "DATETIME" | "DEALLOCATE" | "DO" | "END" | "ENGINE" | "ENGINES" | "EXECUTE" | "FIRST" | "FULL" -| "LOCAL" | "NAMES" | "OFFSET" | "PASSWORD" %prec lowerThanEq | "PREPARE" | "QUICK" | "ROLLBACK" | "SESSION" | "SIGNED" -| "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" -| "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION" -| "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS" -| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" - -NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" -| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" -| "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" - -/************************************************************************************ - * - * Insert Statments - * - * TODO: support PARTITION - **********************************************************************************/ -InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableName InsertRest OnDuplicateKeyUpdate - { - x := $6.(*ast.InsertStmt) - x.Priority = $2.(int) - x.Table = $5.(*ast.TableName) - if $7 != nil { - x.OnDuplicate = $7.([]*ast.Assignment) - } - $$ = x - if yylex.(*lexer).root { - break - } - } - -IntoOpt: - { - } -| "INTO" - { - } - -InsertRest: - '(' ColumnNameListOpt ')' ValueSym ExpressionListList - { - $$ = &ast.InsertStmt{ - Columns: $2.([]*ast.ColumnName), - Lists: $5.([][]ast.ExprNode)} - } -| '(' ColumnNameListOpt ')' SelectStmt - { - $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.SelectStmt)} - } -| ValueSym ExpressionListList %prec insertValues - { - $$ = &ast.InsertStmt{Lists: $2.([][]ast.ExprNode)} - } -| SelectStmt - { - $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} - } -| "SET" ColumnSetValueList - { - $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} - } - -ValueSym: - "VALUE" -| "VALUES" - -ExpressionListList: - '(' ')' - { - $$ = [][]ast.ExprNode{[]ast.ExprNode{}} - } -| '(' ')' ',' ExpressionListList - { - $$ = append([][]ast.ExprNode{[]ast.ExprNode{}}, $4.([][]ast.ExprNode)...) - } -| '(' ExpressionList ')' - { - $$ = [][]ast.ExprNode{$2.([]ast.ExprNode)} - } -| '(' ExpressionList ')' ',' ExpressionListList - { - $$ = append([][]ast.ExprNode{$2.([]ast.ExprNode)}, $5.([][]ast.ExprNode)...) - } - -ColumnSetValue: - ColumnName eq Expression - { - $$ = &ast.Assignment{ - Column: $1.(*ast.ColumnName), - Expr: $3.(ast.ExprNode), - } - } - -ColumnSetValueList: - { - $$ = []*ast.Assignment{} - } -| ColumnSetValue - { - $$ = []*ast.Assignment{$1.(*ast.Assignment)} - } -| ColumnSetValueList ',' ColumnSetValue - { - $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) - } - -/* - * ON DUPLICATE KEY UPDATE col_name=expr [, col_name=expr] ... - * See: https://dev.mysql.com/doc/refman/5.7/en/insert-on-duplicate.html - */ -OnDuplicateKeyUpdate: - { - $$ = nil - } -| "ON" "DUPLICATE" "KEY" "UPDATE" AssignmentList - { - $$ = $5 - } - -/***********************************Insert Statments END************************************/ - -Literal: - "false" - { - $$ = int64(0) - } -| "NULL" -| "true" - { - $$ = int64(1) - } -| floatLit -| intLit -| stringLit -| hexLit -| bitLit - -Operand: - Literal - { - $$ = &ast.ValueExpr{Val: $1} - } -| ColumnName - { - $$ = &ast.ColumnNameExpr{Name: $1.(*ast.ColumnName)} - } -| '(' Expression ')' - { - $$ = &ast.ParenthesesExpr{Expr: $2.(ast.ExprNode)} - } -| "DEFAULT" %prec lowerThanLeftParen - { - $$ = &ast.DefaultExpr{} - } -| "DEFAULT" '(' ColumnName ')' - { - $$ = &ast.DefaultExpr{Name: $3.(*ast.ColumnName)} - } -| Variable - { - $$ = $1 - } -| "PLACEHOLDER" - { - $$ = &ast.ParamMarkerExpr{} - } -| "ROW" '(' Expression ',' ExpressionList ')' - { - values := append([]ast.ExprNode{$3.(ast.ExprNode)}, $5.([]ast.ExprNode)...) - $$ = &ast.RowExpr{Values: values} - } -| '(' Expression ',' ExpressionList ')' - { - values := append([]ast.ExprNode{$2.(ast.ExprNode)}, $4.([]ast.ExprNode)...) - $$ = &ast.RowExpr{Values: values} - } -| "EXISTS" SubSelect - { - $$ = &ast.ExistsSubqueryExpr{Sel: $2.(*ast.SubqueryExpr)} - } - -OrderBy: - "ORDER" "BY" OrderByList - { - $$ = $3.([]*ast.OrderByItem) - } - -OrderByList: - OrderByItem - { - $$ = []*ast.OrderByItem{$1.(*ast.OrderByItem)} - } -| OrderByList ',' OrderByItem - { - $$ = append($1.([]*ast.OrderByItem), $3.(*ast.OrderByItem)) - } - -OrderByItem: - Expression Order - { - $$ = &ast.OrderByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} - } - -Order: - /* EMPTY */ - { - $$ = false // ASC by default - } -| "ASC" - { - $$ = false - } -| "DESC" - { - $$ = true - } - -OrderByOptional: - { - $$ = nil - } -| OrderBy - { - $$ = $1 - } - -PrimaryExpression: - Operand -| Function -| SubSelect -| '!' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} - } -| '~' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.BitNeg, V: $2.(ast.ExprNode)} - } -| '-' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: $2.(ast.ExprNode)} - } -| '+' PrimaryExpression %prec neg - { - $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: $2.(ast.ExprNode)} - } -| "BINARY" PrimaryExpression %prec neg - { - // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#operator_binary - x := types.NewFieldType(mysql.TypeString) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = &ast.FuncCastExpr{ - Expr: $2.(ast.ExprNode), - Tp: x, - FunctionType: ast.CastBinaryOperator, - } - } -| PrimaryExpression "COLLATE" StringName %prec neg - { - // TODO: Create a builtin function hold expr and collation. When do evaluation, convert expr result using the collation. - $$ = $1 - } - -Function: - FunctionCallKeyword -| FunctionCallNonKeyword -| FunctionCallConflict -| FunctionCallAgg - -FunctionNameConflict: - "DATABASE" | "SCHEMA" | "IF" | "LEFT" | "REPEAT" | "CURRENT_USER" | "CURRENT_DATE" - -FunctionCallConflict: - FunctionNameConflict '(' ExpressionListOpt ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: false} - } -| "CURRENT_USER" - { - // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user - $$ = &ast.FuncCallExpr{F: $1.(string)} - } -| "CURRENT_DATE" - { - $$ = &ast.FuncCallExpr{F: $1.(string)} - } - -DistinctOpt: - { - $$ = false - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } -| "DISTINCT" "ALL" - { - $$ = true - } - -FunctionCallKeyword: - "AVG" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "CAST" '(' Expression "AS" CastType ')' - { - /* See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_cast */ - $$ = &ast.FuncCastExpr{ - Expr: $3.(ast.ExprNode), - Tp: $5.(*types.FieldType), - FunctionType: ast.CastFunction, - } - } -| "CASE" ExpressionOpt WhenClauseList ElseOpt "END" - { - x := &ast.CaseExpr{WhenClauses: $3.([]*ast.WhenClause)} - if $2 != nil { - x.Value = $2.(ast.ExprNode) - } - if $4 != nil { - x.ElseClause = $4.(ast.ExprNode) - } - $$ = x - } -| "CONVERT" '(' Expression "USING" StringName ')' - { - // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &ast.FuncConvertExpr{ - Expr: $3.(ast.ExprNode), - Charset: $5.(string), - } - } -| "CONVERT" '(' Expression ',' CastType ')' - { - // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &ast.FuncCastExpr{ - Expr: $3.(ast.ExprNode), - Tp: $5.(*types.FieldType), - FunctionType: ast.CastConvertFunction, - } - } -| "DATE" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "USER" '(' ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string)} - } -| "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues - { - // TODO: support qualified identifier for column_name - $$ = &ast.ColumnNameExpr{Name: $3.(*ast.ColumnName)} - } -| "WEEK" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "YEAR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} - } - -FunctionCallNonKeyword: - "COALESCE" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "CURDATE" '(' ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string)} - } -| "CURRENT_TIMESTAMP" FuncDatetimePrec - { - args := []ast.ExprNode{} - if $2 != nil { - args = append(args, $2.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} - } -| "ABS" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "CONCAT" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "CONCAT_WS" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "DAY" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "DAYOFWEEK" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "DAYOFMONTH" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "DAYOFYEAR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "EXTRACT" '(' TimeUnit "FROM" Expression ')' - { - $$ = &ast.FuncExtractExpr{ - Unit: $3.(string), - Date: $5.(ast.ExprNode), - } - } -| "FOUND_ROWS" '(' ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string)} - } -| "HOUR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "IFNULL" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "LENGTH" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "LOCATE" '(' Expression ',' Expression ')' - { - $$ = &ast.FuncLocateExpr{ - SubStr: $3.(ast.ExprNode), - Str: $5.(ast.ExprNode), - } - } -| "LOCATE" '(' Expression ',' Expression ',' Expression ')' - { - $$ = &ast.FuncLocateExpr{ - SubStr: $3.(ast.ExprNode), - Str: $5.(ast.ExprNode), - Pos: $7.(ast.ExprNode), - } - } -| "LOWER" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "MICROSECOND" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "MINUTE" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "MONTH" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "NOW" '(' ExpressionOpt ')' - { - args := []ast.ExprNode{} - if $3 != nil { - args = append(args, $3.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} - } -| "NULLIF" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} - } -| "RAND" '(' ExpressionOpt ')' - { - - args := []ast.ExprNode{} - if $3 != nil { - args = append(args, $3.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} - } -| "SECOND" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "SUBSTRING" '(' Expression ',' Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - } - } -| "SUBSTRING" '(' Expression "FROM" Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - } - } -| "SUBSTRING" '(' Expression ',' Expression ',' Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - Len: $7.(ast.ExprNode), - } - } -| "SUBSTRING" '(' Expression "FROM" Expression "FOR" Expression ')' - { - $$ = &ast.FuncSubstringExpr{ - StrExpr: $3.(ast.ExprNode), - Pos: $5.(ast.ExprNode), - Len: $7.(ast.ExprNode), - } - } -| "SUBSTRING_INDEX" '(' Expression ',' Expression ',' Expression ')' - { - $$ = &ast.FuncSubstringIndexExpr{ - StrExpr: $3.(ast.ExprNode), - Delim: $5.(ast.ExprNode), - Count: $7.(ast.ExprNode), - } - } -| "SYSDATE" '(' ExpressionOpt ')' - { - args := []ast.ExprNode{} - if $3 != nil { - args = append(args, $3.(ast.ExprNode)) - } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} - } -| "TRIM" '(' Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $3.(ast.ExprNode), - } - } -| "TRIM" '(' Expression "FROM" Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $5.(ast.ExprNode), - RemStr: $3.(ast.ExprNode), - } - } -| "TRIM" '(' TrimDirection "FROM" Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $5.(ast.ExprNode), - Direction: $3.(ast.TrimDirectionType), - } - } -| "TRIM" '(' TrimDirection Expression "FROM" Expression ')' - { - $$ = &ast.FuncTrimExpr{ - Str: $6.(ast.ExprNode), - RemStr: $4.(ast.ExprNode), - Direction: $3.(ast.TrimDirectionType), - } - } -| "UPPER" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "WEEKDAY" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "WEEKOFYEAR" '(' Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} - } -| "YEARWEEK" '(' ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} - } - -TrimDirection: - "BOTH" - { - $$ = ast.TrimBoth - } -| "LEADING" - { - $$ = ast.TrimLeading - } -| "TRAILING" - { - $$ = ast.TrimTrailing - } - -FunctionCallAgg: - "COUNT" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "COUNT" '(' DistinctOpt '*' ')' - { - args := []ast.ExprNode{&ast.ValueExpr{Val: ast.TypeStar("*")} } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} - } -| "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "MAX" '(' DistinctOpt Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} - } -| "MIN" '(' DistinctOpt Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} - } -| "SUM" '(' DistinctOpt Expression ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} - } - -FuncDatetimePrec: - { - $$ = nil - } -| '(' ')' - { - $$ = nil - } -| '(' Expression ')' - { - $$ = $2 - } - -TimeUnit: - "MICROSECOND" | "SECOND" | "MINUTE" | "HOUR" | "DAY" | "WEEK" -| "MONTH" | "QUARTER" | "YEAR" | "SECOND_MICROSECOND" | "MINUTE_MICROSECOND" -| "MINUTE_SECOND" | "HOUR_MICROSECOND" | "HOUR_SECOND" | "HOUR_MINUTE" -| "DAY_MICROSECOND" | "DAY_SECOND" | "DAY_MINUTE" | "DAY_HOUR" | "YEAR_MONTH" - -ExpressionOpt: - { - $$ = nil - } -| Expression - { - $$ = $1 - } - -WhenClauseList: - WhenClause - { - $$ = []*ast.WhenClause{$1.(*ast.WhenClause)} - } -| WhenClauseList WhenClause - { - $$ = append($1.([]*ast.WhenClause), $2.(*ast.WhenClause)) - } - -WhenClause: - "WHEN" Expression "THEN" Expression - { - $$ = &ast.WhenClause{ - Expr: $2.(ast.ExprNode), - Result: $4.(ast.ExprNode), - } - } - -ElseOpt: - /* empty */ - { - $$ = nil - } -| "ELSE" Expression - { - $$ = $2 - } - -CastType: - "BINARY" OptFieldLen - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "CHAR" OptFieldLen OptBinary OptCharset - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $2.(int) - if $3.(bool) { - x.Flag |= mysql.BinaryFlag - } - x.Charset = $4.(string) - $$ = x - } -| "DATE" - { - x := types.NewFieldType(mysql.TypeDate) - $$ = x - } -| "DATETIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDatetime) - x.Decimal = $2.(int) - $$ = x - } -| "DECIMAL" FloatOpt - { - fopt := $2.(*ast.FloatOpt) - x := types.NewFieldType(mysql.TypeNewDecimal) - x.Flen = fopt.Flen - x.Decimal = fopt.Decimal - $$ = x - } -| "TIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDuration) - x.Decimal = $2.(int) - $$ = x - } -| "SIGNED" OptInteger - { - x := types.NewFieldType(mysql.TypeLonglong) - $$ = x - } -| "UNSIGNED" OptInteger - { - x := types.NewFieldType(mysql.TypeLonglong) - x.Flag |= mysql.UnsignedFlag - $$ = x - } - - -PrimaryFactor: - PrimaryFactor '|' PrimaryFactor %prec '|' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Or, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '&' PrimaryFactor %prec '&' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.And, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor "<<" PrimaryFactor %prec lsh - { - $$ = &ast.BinaryOperationExpr{Op: opcode.LeftShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor ">>" PrimaryFactor %prec rsh - { - $$ = &ast.BinaryOperationExpr{Op: opcode.RightShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '+' PrimaryFactor %prec '+' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Plus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '-' PrimaryFactor %prec '-' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Minus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '*' PrimaryFactor %prec '*' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Mul, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '/' PrimaryFactor %prec '/' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Div, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '%' PrimaryFactor %prec '%' - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor "DIV" PrimaryFactor %prec div - { - $$ = &ast.BinaryOperationExpr{Op: opcode.IntDiv, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor "MOD" PrimaryFactor %prec mod - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryFactor '^' PrimaryFactor - { - $$ = &ast.BinaryOperationExpr{Op: opcode.Xor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} - } -| PrimaryExpression - - -Priority: - { - $$ = ast.NoPriority - } -| "LOW_PRIORITY" - { - $$ = ast.LowPriority - } -| "HIGH_PRIORITY" - { - $$ = ast.HighPriority - } -| "DELAYED" - { - $$ = ast.DelayedPriority - } - -LowPriorityOptional: - { - $$ = false - } -| "LOW_PRIORITY" - { - $$ = true - } - -TableName: - Identifier - { - $$ = &ast.TableName{Name:model.NewCIStr($1.(string))} - } -| Identifier '.' Identifier - { - $$ = &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - } - -TableNameList: - TableName - { - tbl := []*ast.TableName{$1.(*ast.TableName)} - $$ = tbl - } -| TableNameList ',' TableName - { - $$ = append($1.([]*ast.TableName), $3.(*ast.TableName)) - } - -QuickOptional: - %prec lowerThanQuick - { - $$ = false - } -| "QUICK" - { - $$ = true - } - -/***************************Prepared Statement Start****************************** - * See: https://dev.mysql.com/doc/refman/5.7/en/prepare.html - * Example: - * PREPARE stmt_name FROM 'SELECT SQRT(POW(?,2) + POW(?,2)) AS hypotenuse'; - * OR - * SET @s = 'SELECT SQRT(POW(?,2) + POW(?,2)) AS hypotenuse'; - * PREPARE stmt_name FROM @s; - */ - -PreparedStmt: - "PREPARE" Identifier "FROM" PrepareSQL - { - var sqlText string - var sqlVar *ast.VariableExpr - switch $4.(type) { - case string: - sqlText = $4.(string) - case *ast.VariableExpr: - sqlVar = $4.(*ast.VariableExpr) - } - $$ = &ast.PrepareStmt{ - InPrepare: true, - Name: $2.(string), - SQLText: sqlText, - SQLVar: sqlVar, - } - } - -PrepareSQL: - stringLit -| UserVariable - - -/* - * See: https://dev.mysql.com/doc/refman/5.7/en/execute.html - * Example: - * EXECUTE stmt1 USING @a, @b; - * OR - * EXECUTE stmt1; - */ -ExecuteStmt: - "EXECUTE" Identifier - { - $$ = &ast.ExecuteStmt{Name: $2.(string)} - } -| "EXECUTE" Identifier "USING" UserVariableList - { - $$ = &ast.ExecuteStmt{ - Name: $2.(string), - UsingVars: $4.([]ast.ExprNode), - } - } - -UserVariableList: - UserVariable - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| UserVariableList ',' UserVariable - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) - } - -/* - * See: https://dev.mysql.com/doc/refman/5.0/en/deallocate-prepare.html - */ - -DeallocateStmt: - DeallocateSym "PREPARE" Identifier - { - $$ = &ast.DeallocateStmt{Name: $3.(string)} - } - -DeallocateSym: - "DEALLOCATE" | "DROP" - -/****************************Prepared Statement End*******************************/ - - -RollbackStmt: - "ROLLBACK" - { - $$ = &ast.RollbackStmt{} - } - -SelectStmt: - SelectBasic -| SelectBasic UnionClauseList - { - st := $1.(*ast.SelectStmt) - st.Unions = $2.([]*ast.UnionClause) - $$ = st - } -| SubSelect UnionClausePList OrderByOptional SelectStmtLimit - { - st := $1.(*ast.SubqueryExpr).Query - st.Unions = $2.([]*ast.UnionClause) - st.UnionOrderBy = $3.([]*ast.OrderByItem) - st.UnionLimit = $4.(*ast.Limit) - $$ = st - } - -SelectBasic: - "SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt - { - $$ = &ast.SelectStmt { - Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), - From: nil, - LockTp: $6.(ast.SelectLockType), - } - } -| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - FromClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional - SelectStmtLimit SelectLockOpt - { - st := &ast.SelectStmt{ - Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), - From: $5.(*ast.Join), - LockTp: $11.(ast.SelectLockType), - } - - if $6 != nil { - st.Where = $6.(ast.ExprNode) - } - - if $7 != nil { - st.GroupBy = $7.([]ast.ExprNode) - } - - if $8 != nil { - st.Having = $8.(ast.ExprNode) - } - - if $9 != nil { - st.OrderBy = $9.([]*ast.OrderByItem) - } - - if $10 != nil { - st.Limit = $10.(*ast.Limit) - } - - $$ = st - } - -FromDual: - /* Empty */ -| "FROM" "DUAL" - - -FromClause: - TableRefs - { - $$ = $1 - } - -TableRefs: - EscapedTableRef - { - if j, ok := $1.(*ast.Join); ok { - // if $1 is JoinRset, use it directly - $$ = j - } else { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: nil} - } - } -| TableRefs ',' EscapedTableRef - { - /* from a, b is default cross join */ - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} - } - -EscapedTableRef: - TableRef %prec lowerThanSetKeyword - { - $$ = $1 - } -| '{' Identifier TableRef '}' - { - /* - * ODBC escape syntax for outer join is { OJ join_table } - * Use an Identifier for OJ - */ - $$ = $3 - } - -TableRef: - TableFactor - { - $$ = $1 - } -| JoinTable - { - $$ = $1 - } - -TableFactor: - TableName TableAsNameOpt - { - $$ = &ast.TableSource{Source: $1.(*ast.TableName), AsName: $2.(model.CIStr)} - } -| '(' SelectStmt ')' TableAsName - { - $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} - } -| '(' TableRefs ')' - { - $$ = $2 - } - -TableAsNameOpt: - { - $$ = model.CIStr{} - } -| TableAsName - { - $$ = $1 - } - -TableAsName: - Identifier - { - $$ = model.NewCIStr($1.(string)) - } -| "AS" Identifier - { - $$ = model.NewCIStr($2.(string)) - } - -JoinTable: - /* Use %prec to evaluate production TableRef before cross join */ - TableRef CrossOpt TableRef %prec tableRefPriority - { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} - } -| TableRef CrossOpt TableRef "ON" Expression - { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: $5.(ast.ExprNode)} - } -| TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression - { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: $7.(ast.ExprNode)} - } - /* Support Using */ - -JoinType: - "LEFT" - { - $$ = ast.LeftJoin - } -| "RIGHT" - { - $$ = ast.RightJoin - } - -OuterOpt: - { - $$ = nil - } -| "OUTER" - - -CrossOpt: - "JOIN" -| "CROSS" "JOIN" -| "INNER" "JOIN" - - -LimitClause: - { - $$ = nil - } -| "LIMIT" LengthNum - { - $$ = &ast.Limit{Count: $2.(uint64)} - } - -SelectStmtLimit: - { - $$ = nil - } -| "LIMIT" LengthNum - { - $$ = &ast.Limit{Count: $2.(uint64)} - } -| "LIMIT" LengthNum ',' LengthNum - { - $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} - } -| "LIMIT" LengthNum "OFFSET" LengthNum - { - $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} - } - -SelectStmtDistinct: - /* EMPTY */ - { - $$ = false - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } - -SelectStmtOpts: - SelectStmtDistinct SelectStmtCalcFoundRows - { - // TODO: return calc_found_rows opt and support more other options - $$ = $1 - } - -SelectStmtCalcFoundRows: - %prec lowerThanCalcFoundRows - { - $$ = false - } -| "SQL_CALC_FOUND_ROWS" - { - $$ = true - } - -SelectStmtFieldList: - FieldList - { - $$ = $1 - } - -SelectStmtGroup: - /* EMPTY */ - { - $$ = nil - } -| GroupByClause - -// See: https://dev.mysql.com/doc/refman/5.7/en/subqueries.html -SubSelect: - '(' SelectStmt ')' - { - s := $2.(*ast.SelectStmt) - src := yylex.(*lexer).src - // See the implemention of yyParse function - s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) - $$ = &ast.SubqueryExpr{Query: s} - } - -// See: https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html -SelectLockOpt: - /* empty */ - { - $$ = ast.SelectLockNone - } -| "FOR" "UPDATE" - { - $$ = ast.SelectLockForUpdate - } -| "LOCK" "IN" "SHARE" "MODE" - { - $$ = ast.SelectLockInShareMode - } - -// See: https://dev.mysql.com/doc/refman/5.7/en/union.html -UnionClause: - "UNION" UnionOpt SelectBasic - { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SelectStmt)} - } - -UnionClauseP: - "UNION" UnionOpt SubSelect - { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SubqueryExpr).Query} - } - -UnionClauseList: - UnionClause - { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} - } -| UnionClauseList UnionClause - { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) - } - -UnionClausePList: - UnionClauseP - { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} - } -| UnionClausePList UnionClauseP - { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) - } - -UnionOpt: - { - $$ = true - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } - - -/********************Set Statement*******************************/ -SetStmt: - "SET" VariableAssignmentList - { - $$ = &ast.SetStmt{Variables: $2.([]*ast.VariableAssignment)} - } -| "SET" "NAMES" StringName - { - $$ = &ast.SetCharsetStmt{Charset: $3.(string)} - } -| "SET" "NAMES" StringName "COLLATE" StringName - { - $$ = &ast.SetCharsetStmt{ - Charset: $3.(string), - Collate: $5.(string), - } - } -| "SET" CharsetKw StringName - { - $$ = &ast.SetCharsetStmt{Charset: $3.(string)} - } -| "SET" "PASSWORD" eq PasswordOpt - { - $$ = &ast.SetPwdStmt{Password: $4.(string)} - } -| "SET" "PASSWORD" "FOR" Username eq PasswordOpt - { - $$ = &ast.SetPwdStmt{User: $4.(string), Password: $6.(string)} - } - -VariableAssignment: - Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $1.(string), Value: $3.(ast.ExprNode), IsSystem: true} - } -| "GLOBAL" Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsGlobal: true, IsSystem: true} - } -| "SESSION" Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} - } -| "LOCAL" Identifier eq Expression - { - $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} - } -| "SYS_VAR" eq Expression - { - v := strings.ToLower($1.(string)) - var isGlobal bool - if strings.HasPrefix(v, "@@global.") { - isGlobal = true - v = strings.TrimPrefix(v, "@@global.") - } else if strings.HasPrefix(v, "@@session.") { - v = strings.TrimPrefix(v, "@@session.") - } else if strings.HasPrefix(v, "@@local.") { - v = strings.TrimPrefix(v, "@@local.") - } else if strings.HasPrefix(v, "@@") { - v = strings.TrimPrefix(v, "@@") - } - $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode), IsGlobal: isGlobal, IsSystem: true} - } -| "USER_VAR" eq Expression - { - v := $1.(string) - v = strings.TrimPrefix(v, "@") - $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode)} - } - -VariableAssignmentList: - { - $$ = []*ast.VariableAssignment{} - } -| VariableAssignment - { - $$ = []*ast.VariableAssignment{$1.(*ast.VariableAssignment)} - } -| VariableAssignmentList ',' VariableAssignment - { - $$ = append($1.([]*ast.VariableAssignment), $3.(*ast.VariableAssignment)) - } - -Variable: - SystemVariable | UserVariable - -SystemVariable: - "SYS_VAR" - { - v := strings.ToLower($1.(string)) - var isGlobal bool - if strings.HasPrefix(v, "@@global.") { - isGlobal = true - v = strings.TrimPrefix(v, "@@global.") - } else if strings.HasPrefix(v, "@@session.") { - v = strings.TrimPrefix(v, "@@session.") - } else if strings.HasPrefix(v, "@@local.") { - v = strings.TrimPrefix(v, "@@local.") - } else if strings.HasPrefix(v, "@@") { - v = strings.TrimPrefix(v, "@@") - } - $$ = &ast.VariableExpr{Name: v, IsGlobal: isGlobal, IsSystem: true} - } - -UserVariable: - "USER_VAR" - { - v := $1.(string) - v = strings.TrimPrefix(v, "@") - $$ = &ast.VariableExpr{Name: v, IsGlobal: false, IsSystem: false} - } - -Username: - stringLit "AT" stringLit - { - $$ = $1.(string) + "@" + $3.(string) - } - -PasswordOpt: - stringLit - { - $$ = $1.(string) - } -| "PASSWORD" '(' AuthString ')' - { - $$ = $3.(string) - } - -AuthString: - stringLit - { - $$ = $1.(string) - } - -/****************************Show Statement*******************************/ -ShowStmt: - "SHOW" "ENGINES" - { - $$ = &ast.ShowStmt{Tp: ast.ShowEngines} - } -| "SHOW" "DATABASES" - { - $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} - } -| "SHOW" "SCHEMAS" - { - $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} - } -| "SHOW" "CHARACTER" "SET" - { - $$ = &ast.ShowStmt{Tp: ast.ShowCharset} - } -| "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt - { - $$ = &ast.ShowStmt{ - Tp: ast.ShowTables, - DBName: $4.(string), - Full: $2.(bool), - Where: $5.(ast.ExprNode), - } - } -| "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt - { - $$ = &ast.ShowStmt{ - Tp: ast.ShowColumns, - Table: $4.(*ast.TableName), - DBName: $5.(string), - Full: $2.(bool), - } - } -| "SHOW" "WARNINGS" - { - $$ = &ast.ShowStmt{Tp: ast.ShowWarnings} - } -| "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt - { - $$ = &ast.ShowStmt{ - Tp: ast.ShowVariables, - GlobalScope: $2.(bool), - Where: $4.(ast.ExprNode), - } - } -| "SHOW" "COLLATION" ShowLikeOrWhereOpt - { - $$ = &ast.ShowStmt{ - Tp: ast.ShowCollation, - Where: $3.(ast.ExprNode), - } - } -| "SHOW" "CREATE" "TABLE" TableName - { - $$ = &ast.ShowStmt{ - Tp: ast.ShowCreateTable, - Table: $4.(*ast.TableName), - } - } - -ShowLikeOrWhereOpt: - { - $$ = nil - } -| "LIKE" PrimaryExpression - { - $$ = &ast.PatternLikeExpr{Pattern: $2.(ast.ExprNode)} - } -| "WHERE" Expression - { - $$ = $2.(ast.ExprNode) - } - -GlobalScope: - { - $$ = false - } -| "GLOBAL" - { - $$ = true - } -| "SESSION" - { - $$ = false - } - -OptFull: - { - $$ = false - } -| "FULL" - { - $$ = true - } - -ShowDatabaseNameOpt: - { - $$ = "" - } -| "FROM" DBName - { - $$ = $2.(string) - } -| "IN" DBName - { - $$ = $2.(string) - } - -ShowTableAliasOpt: - "FROM" TableName - { - $$ = $2.(*ast.TableName) - } -| "IN" TableName - { - $$ = $2.(*ast.TableName) - } - -Statement: - EmptyStmt -| AlterTableStmt -| BeginTransactionStmt -| CommitStmt -| DeallocateStmt -| DeleteFromStmt -| ExecuteStmt -| ExplainStmt -| CreateDatabaseStmt -| CreateIndexStmt -| CreateTableStmt -| CreateUserStmt -| DoStmt -| DropDatabaseStmt -| DropIndexStmt -| DropTableStmt -| InsertIntoStmt -| PreparedStmt -| RollbackStmt -| SelectStmt -| SetStmt -| ShowStmt -| TruncateTableStmt -| UpdateStmt -| UseStmt -| SubSelect - { - // `(select 1)`; is a valid select statement - // TODO: This is used to fix issue #320. There may be a better solution. - $$ = $1.(*ast.SubqueryExpr).Query - } - -ExplainableStmt: - SelectStmt -| DeleteFromStmt -| UpdateStmt -| InsertIntoStmt - -StatementList: - Statement - { - if $1 != nil { - s := $1.(ast.StmtNode) - s.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = append(yylex.(*lexer).list, s) - } - } -| StatementList ';' Statement - { - if $3 != nil { - s := $3.(ast.StmtNode) - s.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = append(yylex.(*lexer).list, s) - } - } - -Constraint: - ConstraintKeywordOpt ConstraintElem - { - cst := $2.(*ast.Constraint) - if $1 != nil { - cst.Name = $1.(string) - } - $$ = cst - } - -TableElement: - ColumnDef - { - $$ = $1.(*ast.ColumnDef) - } -| Constraint - { - $$ = $1.(*ast.Constraint) - } -| "CHECK" '(' Expression ')' - { - /* Nothing to do now */ - $$ = nil - } - -TableElementList: - TableElement - { - if $1 != nil { - $$ = []interface{}{$1.(interface{})} - } else { - $$ = []interface{}{} - } - } -| TableElementList ',' TableElement - { - if $3 != nil { - $$ = append($1.([]interface{}), $3) - } else { - $$ = $1 - } - } - -TableOption: - "ENGINE" Identifier - { - $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $2.(string)} - } -| "ENGINE" eq Identifier - { - $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $3.(string)} - } -| DefaultKwdOpt CharsetKw EqOpt StringName - { - $$ = &ast.TableOption{Tp: ast.TableOptionCharset, StrValue: $4.(string)} - } -| DefaultKwdOpt "COLLATE" EqOpt StringName - { - $$ = &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: $4.(string)} - } -| "AUTO_INCREMENT" eq LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionAutoIncrement, UintValue: $3.(uint64)} - } -| "COMMENT" EqOpt stringLit - { - $$ = &ast.TableOption{Tp: ast.TableOptionComment, StrValue: $3.(string)} - } -| "AVG_ROW_LENGTH" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionAvgRowLength, UintValue: $3.(uint64)} - } -| "CONNECTION" EqOpt stringLit - { - $$ = &ast.TableOption{Tp: ast.TableOptionConnection, StrValue: $3.(string)} - } -| "CHECKSUM" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionCheckSum, UintValue: $3.(uint64)} - } -| "PASSWORD" EqOpt stringLit - { - $$ = &ast.TableOption{Tp: ast.TableOptionPassword, StrValue: $3.(string)} - } -| "COMPRESSION" EqOpt Identifier - { - $$ = &ast.TableOption{Tp: ast.TableOptionCompression, StrValue: $3.(string)} - } -| "KEY_BLOCK_SIZE" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionKeyBlockSize, UintValue: $3.(uint64)} - } -| "MAX_ROWS" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionMaxRows, UintValue: $3.(uint64)} - } -| "MIN_ROWS" EqOpt LengthNum - { - $$ = &ast.TableOption{Tp: ast.TableOptionMinRows, UintValue: $3.(uint64)} - } - - -TableOptionListOpt: - { - $$ = []*ast.TableOption{} - } -| TableOptionList %prec lowerThanComma - -TableOptionList: - TableOption - { - $$ = []*ast.TableOption{$1.(*ast.TableOption)} - } -| TableOptionList TableOption - { - $$ = append($1.([]*ast.TableOption), $2.(*ast.TableOption)) - } -| TableOptionList ',' TableOption - { - $$ = append($1.([]*ast.TableOption), $3.(*ast.TableOption)) - } - - -TruncateTableStmt: - "TRUNCATE" "TABLE" TableName - { - $$ = &ast.TruncateTableStmt{Table: $3.(*ast.TableName)} - } - -/*************************************Type Begin***************************************/ -Type: - NumericType - { - $$ = $1 - } -| StringType - { - $$ = $1 - } -| DateAndTimeType - { - $$ = $1 - } -| "float32" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "float64" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "int64" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "string" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "uint" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } -| "uint64" - { - x := types.NewFieldType($1.(byte)) - $$ = x - } - -NumericType: - IntegerType OptFieldLen FieldOpts - { - // TODO: check flen 0 - x := types.NewFieldType($1.(byte)) - x.Flen = $2.(int) - for _, o := range $3.([]*field.Opt) { - if o.IsUnsigned { - x.Flag |= mysql.UnsignedFlag - } - if o.IsZerofill { - x.Flag |= mysql.ZerofillFlag - } - } - $$ = x - } -| FixedPointType FloatOpt FieldOpts - { - fopt := $2.(*ast.FloatOpt) - x := types.NewFieldType($1.(byte)) - x.Flen = fopt.Flen - x.Decimal = fopt.Decimal - for _, o := range $3.([]*field.Opt) { - if o.IsUnsigned { - x.Flag |= mysql.UnsignedFlag - } - if o.IsZerofill { - x.Flag |= mysql.ZerofillFlag - } - } - $$ = x - } -| FloatingPointType FloatOpt FieldOpts - { - fopt := $2.(*ast.FloatOpt) - x := types.NewFieldType($1.(byte)) - x.Flen = fopt.Flen - if x.Tp == mysql.TypeFloat { - // Fix issue #312 - if x.Flen > 53 { - yylex.(*lexer).errf("Float len(%d) should not be greater than 53", x.Flen) - return 1 - } - if x.Flen > 24 { - x.Tp = mysql.TypeDouble - } - } - x.Decimal =fopt.Decimal - for _, o := range $3.([]*field.Opt) { - if o.IsUnsigned { - x.Flag |= mysql.UnsignedFlag - } - if o.IsZerofill { - x.Flag |= mysql.ZerofillFlag - } - } - $$ = x - } -| BitValueType OptFieldLen - { - x := types.NewFieldType($1.(byte)) - x.Flen = $2.(int) - if x.Flen == -1 || x.Flen == 0 { - x.Flen = 1 - } else if x.Flen > 64 { - yylex.(*lexer).errf("invalid field length %d for bit type, must in [1, 64]", x.Flen) - } - $$ = x - } - -IntegerType: - "TINYINT" - { - $$ = mysql.TypeTiny - } -| "SMALLINT" - { - $$ = mysql.TypeShort - } -| "MEDIUMINT" - { - $$ = mysql.TypeInt24 - } -| "INT" - { - $$ = mysql.TypeLong - } -| "INTEGER" - { - $$ = mysql.TypeLong - } -| "BIGINT" - { - $$ = mysql.TypeLonglong - } -| "BOOL" - { - $$ = mysql.TypeTiny - } -| "BOOLEAN" - { - $$ = mysql.TypeTiny - } - -OptInteger: - {} | "INTEGER" - -FixedPointType: - "DECIMAL" - { - $$ = mysql.TypeNewDecimal - } -| "NUMERIC" - { - $$ = mysql.TypeNewDecimal - } - -FloatingPointType: - "float" - { - $$ = mysql.TypeFloat - } -| "REAL" - { - $$ = mysql.TypeDouble - } -| "DOUBLE" - { - $$ = mysql.TypeDouble - } -| "DOUBLE" "PRECISION" - { - $$ = mysql.TypeDouble - } - -BitValueType: - "BIT" - { - $$ = mysql.TypeBit - } - -StringType: - NationalOpt "CHAR" FieldLen OptBinary OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $3.(int) - if $4.(bool) { - x.Flag |= mysql.BinaryFlag - } - $$ = x - } -| NationalOpt "CHAR" OptBinary OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeString) - if $3.(bool) { - x.Flag |= mysql.BinaryFlag - } - $$ = x - } -| NationalOpt "VARCHAR" FieldLen OptBinary OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeVarchar) - x.Flen = $3.(int) - if $4.(bool) { - x.Flag |= mysql.BinaryFlag - } - x.Charset = $5.(string) - x.Collate = $6.(string) - $$ = x - } -| "BINARY" OptFieldLen - { - x := types.NewFieldType(mysql.TypeString) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "VARBINARY" FieldLen - { - x := types.NewFieldType(mysql.TypeVarchar) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| BlobType - { - $$ = $1.(*types.FieldType) - } -| TextType OptBinary OptCharset OptCollate - { - x := $1.(*types.FieldType) - if $2.(bool) { - x.Flag |= mysql.BinaryFlag - } - x.Charset = $3.(string) - x.Collate = $4.(string) - $$ = x - } -| "ENUM" '(' StringList ')' OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeEnum) - x.Elems = $3.([]string) - x.Charset = $5.(string) - x.Collate = $6.(string) - $$ = x - } -| "SET" '(' StringList ')' OptCharset OptCollate - { - x := types.NewFieldType(mysql.TypeSet) - x.Elems = $3.([]string) - x.Charset = $5.(string) - x.Collate = $6.(string) - $$ = x - } - -NationalOpt: - { - - } -| "NATIONAL" - { - - } - -BlobType: - "TINYBLOB" - { - x := types.NewFieldType(mysql.TypeTinyBlob) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "BLOB" OptFieldLen - { - x := types.NewFieldType(mysql.TypeBlob) - x.Flen = $2.(int) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "MEDIUMBLOB" - { - x := types.NewFieldType(mysql.TypeMediumBlob) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } -| "LONGBLOB" - { - x := types.NewFieldType(mysql.TypeLongBlob) - x.Charset = charset.CharsetBin - x.Collate = charset.CharsetBin - $$ = x - } - -TextType: - "TINYTEXT" - { - x := types.NewFieldType(mysql.TypeTinyBlob) - $$ = x - - } -| "TEXT" OptFieldLen - { - x := types.NewFieldType(mysql.TypeBlob) - x.Flen = $2.(int) - $$ = x - } -| "MEDIUMTEXT" - { - x := types.NewFieldType(mysql.TypeMediumBlob) - $$ = x - } -| "LONGTEXT" - { - x := types.NewFieldType(mysql.TypeLongBlob) - $$ = x - } - - -DateAndTimeType: - "DATE" - { - x := types.NewFieldType(mysql.TypeDate) - $$ = x - } -| "DATETIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDatetime) - x.Decimal = $2.(int) - $$ = x - } -| "TIMESTAMP" OptFieldLen - { - x := types.NewFieldType(mysql.TypeTimestamp) - x.Decimal = $2.(int) - $$ = x - } -| "TIME" OptFieldLen - { - x := types.NewFieldType(mysql.TypeDuration) - x.Decimal = $2.(int) - $$ = x - } -| "YEAR" OptFieldLen - { - x := types.NewFieldType(mysql.TypeYear) - x.Flen = $2.(int) - $$ = x - } - -FieldLen: - '(' LengthNum ')' - { - $$ = int($2.(uint64)) - } - -OptFieldLen: - { - /* -1 means unspecified field length*/ - $$ = types.UnspecifiedLength - } -| FieldLen - { - $$ = $1.(int) - } - -FieldOpt: - "UNSIGNED" - { - $$ = &field.Opt{IsUnsigned: true} - } -| "ZEROFILL" - { - $$ = &field.Opt{IsZerofill: true, IsUnsigned: true} - } - -FieldOpts: - { - $$ = []*field.Opt{} - } -| FieldOpts FieldOpt - { - $$ = append($1.([]*field.Opt), $2.(*field.Opt)) - } - -FloatOpt: - { - $$ = &ast.FloatOpt{Flen: types.UnspecifiedLength, Decimal: types.UnspecifiedLength} - } -| FieldLen - { - $$ = &ast.FloatOpt{Flen: $1.(int), Decimal: types.UnspecifiedLength} - } -| Precision - { - $$ = $1.(*ast.FloatOpt) - } - -Precision: - '(' LengthNum ',' LengthNum ')' - { - $$ = &ast.FloatOpt{Flen: int($2.(uint64)), Decimal: int($4.(uint64))} - } - -OptBinary: - { - $$ = false - } -| "BINARY" - { - $$ = true - } - -OptCharset: - { - $$ = "" - } -| CharsetKw StringName - { - $$ = $2.(string) - } - -CharsetKw: - "CHARACTER" "SET" -| "CHARSET" - -OptCollate: - { - $$ = "" - } -| "COLLATE" StringName - { - $$ = $2.(string) - } - -StringList: - stringLit - { - $$ = []string{$1.(string)} - } -| StringList ',' stringLit - { - $$ = append($1.([]string), $3.(string)) - } - -StringName: - stringLit - { - $$ = $1.(string) - } -| Identifier - { - $$ = $1.(string) - } - -/*********************************************************************************** - * Update Statement - * See: https://dev.mysql.com/doc/refman/5.7/en/update.html - ***********************************************************************************/ -UpdateStmt: - "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause - { - // Single-table syntax - r := &ast.Join{Left: $4.(ast.ResultSetNode), Right: nil} - st := &ast.UpdateStmt{ - LowPriority: $2.(bool), - TableRefs: r, - List: $6.([]*ast.Assignment), - } - if $7 != nil { - st.Where = $7.(ast.ExprNode) - } - if $8 != nil { - st.Order = $8.([]*ast.OrderByItem) - } - if $9 != nil { - st.Limit = $9.(*ast.Limit) - } - $$ = st - if yylex.(*lexer).root { - break - } - } -| "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional - { - // Multiple-table syntax - st := &ast.UpdateStmt{ - LowPriority: $2.(bool), - TableRefs: $4.(*ast.Join), - List: $6.([]*ast.Assignment), - MultipleTable: true, - } - if $7 != nil { - st.Where = $7.(ast.ExprNode) - } - $$ = st - if yylex.(*lexer).root { - break - } - } - -UseStmt: - "USE" DBName - { - $$ = &ast.UseStmt{DBName: $2.(string)} - if yylex.(*lexer).root { - break - } - } - -WhereClause: - "WHERE" Expression - { - $$ = $2.(ast.ExprNode) - } - -WhereClauseOptional: - { - $$ = nil - } -| WhereClause - { - $$ = $1 - } - -CommaOpt: - { - } -| ',' - { - } - -/************************************************************************************ - * Account Management Statements - * https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html - ************************************************************************************/ -CreateUserStmt: - "CREATE" "USER" IfNotExists UserSpecList - { - // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html - $$ = &ast.CreateUserStmt{ - IfNotExists: $3.(bool), - Specs: $4.([]*ast.UserSpec), - } - } - -UserSpec: - Username AuthOption - { - $$ = &ast.UserSpec{ - User: $1.(string), - AuthOpt: $2.(*ast.AuthOption), - } - } - -UserSpecList: - UserSpec - { - $$ = []*ast.UserSpec{$1.(*ast.UserSpec)} - } -| UserSpecList ',' UserSpec - { - $$ = append($1.([]*ast.UserSpec), $3.(*ast.UserSpec)) - } - -AuthOption: - {} -| "IDENTIFIED" "BY" AuthString - { - $$ = &ast.AuthOption { - AuthString: $3.(string), - ByAuthString: true, - } - } -| "IDENTIFIED" "BY" "PASSWORD" HashString - { - $$ = &ast.AuthOption { - HashString: $4.(string), - } - } - -HashString: - stringLit -%% diff --git a/ast/parser/parser_test.go b/ast/parser/parser_test.go deleted file mode 100644 index 6b32ea65fc..0000000000 --- a/ast/parser/parser_test.go +++ /dev/null @@ -1,79 +0,0 @@ -package parser - -import ( - "fmt" - "testing" - - . "github.com/pingcap/check" - "github.com/pingcap/tidb/ast" -) - -func TestT(t *testing.T) { - TestingT(t) -} - -var _ = Suite(&testParserSuite{}) - -var _ = Suite(&testParserSuite{}) - -type testParserSuite struct { -} - -func (s *testParserSuite) TestSimple(c *C) { - // Testcase for unreserved keywords - unreservedKws := []string{ - "auto_increment", "after", "begin", "bit", "bool", "boolean", "charset", "columns", "commit", - "date", "datetime", "deallocate", "do", "end", "engine", "engines", "execute", "first", "full", - "local", "names", "offset", "password", "prepare", "quick", "rollback", "session", "signed", - "start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown", - "value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified", - "collation", "comment", "avg_row_length", "checksum", "compression", "connection", "key_block_size", - "max_rows", "min_rows", "national", "row", "quarter", "escape", - } - for _, kw := range unreservedKws { - src := fmt.Sprintf("SELECT %s FROM tbl;", kw) - l := NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - c.Assert(l.errs, HasLen, 0, Commentf("source %s", src)) - } - - // Testcase for prepared statement - src := "SELECT id+?, id+? from t;" - l := NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - c.Assert(len(l.Stmts()), Equals, 1) - - // Testcase for -- Comment and unary -- operator - src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;" - l = NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - c.Assert(len(l.Stmts()), Equals, 2) - - // Testcase for CONVERT(expr,type) - src = "SELECT CONVERT('111', SIGNED);" - l = NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - st := l.Stmts()[0] - ss, ok := st.(*ast.SelectStmt) - c.Assert(ok, IsTrue) - c.Assert(len(ss.Fields), Equals, 1) - cv, ok := ss.Fields[0].Expr.(*ast.FuncCastExpr) - c.Assert(ok, IsTrue) - c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction) - - // For query start with comment - srcs := []string{ - "/* some comments */ SELECT CONVERT('111', SIGNED) ;", - "/* some comments */ /*comment*/ SELECT CONVERT('111', SIGNED) ;", - "SELECT /*comment*/ CONVERT('111', SIGNED) ;", - "SELECT CONVERT('111', /*comment*/ SIGNED) ;", - "SELECT CONVERT('111', SIGNED) /*comment*/;", - } - for _, src := range srcs { - l = NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - st = l.Stmts()[0] - ss, ok = st.(*ast.SelectStmt) - c.Assert(ok, IsTrue) - } -} diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l deleted file mode 100644 index a23a68a72a..0000000000 --- a/ast/parser/scanner.l +++ /dev/null @@ -1,1067 +0,0 @@ -%{ -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package parser - -import ( - "fmt" - "math" - "strconv" - "strings" - "unicode" - - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/mysql" -) - -type lexer struct { - c int - col int - errs []error - expr ast.ExprNode - i int - inj int - lcol int - line int - list []ast.StmtNode - ncol int - nline int - sc int - src string - val []byte - ungetBuf []byte - root bool - stmtStartPos int - stringLit []byte - - // record token's offset of the input - tokenEndOffset int - tokenStartOffset int -} - - -// NewLexer builds a new lexer. -func NewLexer(src string) (l *lexer) { - l = &lexer{ - src: src, - nline: 1, - ncol: 0, - } - l.next() - return -} - -func (l *lexer) Errors() []error { - return l.errs -} - -func (l *lexer) Stmts() []ast.StmtNode { - return l.list -} - -func (l *lexer) Expr() ast.ExprNode { - return l.expr -} - -func (l *lexer) Inj() int { - return l.inj -} - -func (l *lexer) SetInj(inj int) { - l.inj = inj -} - -func (l *lexer) Root() bool { - return l.root -} - -func (l *lexer) SetRoot(root bool) { - l.root = root -} - -func (l *lexer) unget(b byte) { - l.ungetBuf = append(l.ungetBuf, b) - l.i-- - l.ncol-- - l.tokenEndOffset-- -} - -func (l *lexer) next() int { - if un := len(l.ungetBuf); un > 0 { - nc := l.ungetBuf[0] - l.ungetBuf = l.ungetBuf[1:] - l.c = int(nc) - return l.c - } - - if l.c != 0 { - l.val = append(l.val, byte(l.c)) - } - l.c = 0 - if l.i < len(l.src) { - l.c = int(l.src[l.i]) - l.i++ - } - switch l.c { - case '\n': - l.lcol = l.ncol - l.nline++ - l.ncol = 0 - default: - l.ncol++ - } - l.tokenEndOffset++ - return l.c -} - -func (l *lexer) err0(ln, c int, arg interface{}) { - var argStr string - if arg != nil { - argStr = fmt.Sprintf(" %v", arg) - } - - err := fmt.Errorf("line %d column %d near \"%s\"%s", ln, c, l.val, argStr) - l.errs = append(l.errs, err) -} - -func (l *lexer) err(arg interface{}) { - l.err0(l.line, l.col, arg) -} - -func (l *lexer) errf(format string, args ...interface{}) { - s := fmt.Sprintf(format, args...) - l.err0(l.line, l.col, s) -} - -func (l *lexer) Error(s string) { - // Notice: ignore origin error info. - l.err(nil) -} - -func (l *lexer) stmtText() string { - endPos := l.i - if l.src[l.i-1] == '\n' { - endPos = l.i-1 // trim new line - } - if l.src[l.stmtStartPos] == '\n' { - l.stmtStartPos++ - } - - text := l.src[l.stmtStartPos:endPos] - - l.stmtStartPos = l.i - return text -} - - -func (l *lexer) Lex(lval *yySymType) (r int) { - defer func() { - lval.line, lval.col, lval.offset = l.line, l.col, l.tokenStartOffset - l.tokenStartOffset = l.tokenEndOffset - }() - const ( - INITIAL = iota - S1 - S2 - S3 - S4 - ) - - if n := l.inj; n != 0 { - l.inj = 0 - return n - } - - c0, c := 0, l.c -%} - -int_lit {decimal_lit}|{octal_lit} -decimal_lit [1-9][0-9]* -octal_lit 0[0-7]* -hex_lit 0[xX][0-9a-fA-F]+|[xX]"'"[0-9a-fA-F]+"'" -bit_lit 0[bB][01]+|[bB]"'"[01]+"'" - -float_lit {D}"."{D}?{E}?|{D}{E}|"."{D}{E}? -D [0-9]+ -E [eE][-+]?[0-9]+ - -imaginary_ilit {D}i -imaginary_lit {float_lit}i - -a [aA] -b [bB] -c [cC] -d [dD] -e [eE] -f [fF] -g [gG] -h [hH] -i [iI] -j [jJ] -k [kK] -l [lL] -m [mM] -n [nN] -o [oO] -p [pP] -q [qQ] -r [rR] -s [sS] -t [tT] -u [uU] -v [vV] -w [wW] -x [xX] -y [yY] -z [zZ] - -abs {a}{b}{s} -add {a}{d}{d} -after {a}{f}{t}{e}{r} -all {a}{l}{l} -alter {a}{l}{t}{e}{r} -and {a}{n}{d} -any {a}{n}{y} -as {a}{s} -asc {a}{s}{c} -auto_increment {a}{u}{t}{o}_{i}{n}{c}{r}{e}{m}{e}{n}{t} -avg {a}{v}{g} -avg_row_length {a}{v}{g}_{r}{o}{w}_{l}{e}{n}{g}{t}{h} -begin {b}{e}{g}{i}{n} -between {b}{e}{t}{w}{e}{e}{n} -both {b}{o}{t}{h} -by {b}{y} -case {c}{a}{s}{e} -cast {c}{a}{s}{t} -character {c}{h}{a}{r}{a}{c}{t}{e}{r} -charset {c}{h}{a}{r}{s}{e}{t} -check {c}{h}{e}{c}{k} -checksum {c}{h}{e}{c}{k}{s}{u}{m} -coalesce {c}{o}{a}{l}{e}{s}{c}{e} -collate {c}{o}{l}{l}{a}{t}{e} -collation {c}{o}{l}{l}{a}{t}{i}{o}{n} -column {c}{o}{l}{u}{m}{n} -columns {c}{o}{l}{u}{m}{n}{s} -comment {c}{o}{m}{m}{e}{n}{t} -commit {c}{o}{m}{m}{i}{t} -compression {c}{o}{m}{p}{r}{e}{s}{s}{i}{o}{n} -concat {c}{o}{n}{c}{a}{t} -concat_ws {c}{o}{n}{c}{a}{t}_{w}{s} -connection {c}{o}{n}{n}{e}{c}{t}{i}{o}{n} -constraint {c}{o}{n}{s}{t}{r}{a}{i}{n}{t} -convert {c}{o}{n}{v}{e}{r}{t} -count {c}{o}{u}{n}{t} -create {c}{r}{e}{a}{t}{e} -cross {c}{r}{o}{s}{s} -curdate {c}{u}{r}{d}{a}{t}{e} -current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e} -current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r} -database {d}{a}{t}{a}{b}{a}{s}{e} -databases {d}{a}{t}{a}{b}{a}{s}{e}{s} -day {d}{a}{y} -dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k} -dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h} -dayofyear {d}{a}{y}{o}{f}{y}{e}{a}{r} -deallocate {d}{e}{a}{l}{l}{o}{c}{a}{t}{e} -default {d}{e}{f}{a}{u}{l}{t} -delayed {d}{e}{l}{a}{y}{e}{d} -delete {d}{e}{l}{e}{t}{e} -drop {d}{r}{o}{p} -desc {d}{e}{s}{c} -describe {d}{e}{s}{c}{r}{i}{b}{e} -distinct {d}{i}{s}{t}{i}{n}{c}{t} -div {d}{i}{v} -do {d}{o} -dual {d}{u}{a}{l} -duplicate {d}{u}{p}{l}{i}{c}{a}{t}{e} -else {e}{l}{s}{e} -end {e}{n}{d} -engine {e}{n}{g}{i}{n}{e} -engines {e}{n}{g}{i}{n}{e}{s} -escape {e}{s}{c}{a}{p}{e} -execute {e}{x}{e}{c}{u}{t}{e} -exists {e}{x}{i}{s}{t}{s} -explain {e}{x}{p}{l}{a}{i}{n} -extract {e}{x}{t}{r}{a}{c}{t} -first {f}{i}{r}{s}{t} -for {f}{o}{r} -foreign {f}{o}{r}{e}{i}{g}{n} -found_rows {f}{o}{u}{n}{d}_{r}{o}{w}{s} -from {f}{r}{o}{m} -full {f}{u}{l}{l} -fulltext {f}{u}{l}{l}{t}{e}{x}{t} -global {g}{l}{o}{b}{a}{l} -group {g}{r}{o}{u}{p} -group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} -having {h}{a}{v}{i}{n}{g} -high_priority {h}{i}{g}{h}_{p}{r}{i}{o}{r}{i}{t}{y} -hour {h}{o}{u}{r} -identified {i}{d}{e}{n}{t}{i}{f}{i}{e}{d} -if {i}{f} -ifnull {i}{f}{n}{u}{l}{l} -ignore {i}{g}{n}{o}{r}{e} -in {i}{n} -index {i}{n}{d}{e}{x} -inner {i}{n}{n}{e}{r} -insert {i}{n}{s}{e}{r}{t} -into {i}{n}{t}{o} -is {i}{s} -join {j}{o}{i}{n} -key {k}{e}{y} -key_block_size {k}{e}{y}_{b}{l}{o}{c}{k}_{s}{i}{z}{e} -leading {l}{e}{a}{d}{i}{n}{g} -left {l}{e}{f}{t} -length {l}{e}{n}{g}{t}{h} -like {l}{i}{k}{e} -limit {l}{i}{m}{i}{t} -local {l}{o}{c}{a}{l} -locate {l}{o}{c}{a}{t}{e} -lock {l}{o}{c}{k} -lower {l}{o}{w}{e}{r} -low_priority {l}{o}{w}_{p}{r}{i}{o}{r}{i}{t}{y} -max_rows {m}{a}{x}_{r}{o}{w}{s} -microsecond {m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -minute {m}{i}{n}{u}{t}{e} -min_rows {m}{i}{n}_{r}{o}{w}{s} -mod {m}{o}{d} -mode {m}{o}{d}{e} -month {m}{o}{n}{t}{h} -names {n}{a}{m}{e}{s} -national {n}{a}{t}{i}{o}{n}{a}{l} -not {n}{o}{t} -offset {o}{f}{f}{s}{e}{t} -on {o}{n} -or {o}{r} -order {o}{r}{d}{e}{r} -outer {o}{u}{t}{e}{r} -password {p}{a}{s}{s}{w}{o}{r}{d} -prepare {p}{r}{e}{p}{a}{r}{e} -primary {p}{r}{i}{m}{a}{r}{y} -quarter {q}{u}{a}{r}{t}{e}{r} -quick {q}{u}{i}{c}{k} -rand {r}{a}{n}{d} -repeat {r}{e}{p}{e}{a}{t} -references {r}{e}{f}{e}{r}{e}{n}{c}{e}{s} -regexp {r}{e}{g}{e}{x}{p} -right {r}{i}{g}{h}{t} -rlike {r}{l}{i}{k}{e} -rollback {r}{o}{l}{l}{b}{a}{c}{k} -row {r}{o}{w} -schema {s}{c}{h}{e}{m}{a} -schemas {s}{c}{h}{e}{m}{a}{s} -second {s}{e}{c}{o}{n}{d} -select {s}{e}{l}{e}{c}{t} -session {s}{e}{s}{s}{i}{o}{n} -set {s}{e}{t} -share {s}{h}{a}{r}{e} -show {s}{h}{o}{w} -some {s}{o}{m}{e} -start {s}{t}{a}{r}{t} -substring {s}{u}{b}{s}{t}{r}{i}{n}{g} -substring_index {s}{u}{b}{s}{t}{r}{i}{n}{g}_{i}{n}{d}{e}{x} -sum {s}{u}{m} -sysdate {s}{y}{s}{d}{a}{t}{e} -table {t}{a}{b}{l}{e} -tables {t}{a}{b}{l}{e}{s} -then {t}{h}{e}{n} -trailing {t}{r}{a}{i}{l}{i}{n}{g} -transaction {t}{r}{a}{n}{s}{a}{c}{t}{i}{o}{n} -trim {t}{r}{i}{m} -truncate {t}{r}{u}{n}{c}{a}{t}{e} -max {m}{a}{x} -min {m}{i}{n} -unknown {u}{n}{k}{n}{o}{w}{n} -union {u}{n}{i}{o}{n} -unique {u}{n}{i}{q}{u}{e} -nullif {n}{u}{l}{l}{i}{f} -update {u}{p}{d}{a}{t}{e} -upper {u}{p}{p}{e}{r} -value {v}{a}{l}{u}{e} -values {v}{a}{l}{u}{e}{s} -variables {v}{a}{r}{i}{a}{b}{l}{e}{s} -warnings {w}{a}{r}{n}{i}{n}{g}{s} -week {w}{e}{e}{k} -weekday {w}{e}{e}{k}{d}{a}{y} -weekofyear {w}{e}{e}{k}{o}{f}{y}{e}{a}{r} -where {w}{h}{e}{r}{e} -when {w}{h}{e}{n} -xor {x}{o}{r} -yearweek {y}{e}{a}{r}{w}{e}{e}{k} - -null {n}{u}{l}{l} -false {f}{a}{l}{s}{e} -true {t}{r}{u}{e} - -calc_found_rows {s}{q}{l}_{c}{a}{l}{c}_{f}{o}{u}{n}{d}_{r}{o}{w}{s} - -current_ts {c}{u}{r}{r}{e}{n}{t}_{t}{i}{m}{e}{s}{t}{a}{m}{p} -localtime {l}{o}{c}{a}{l}{t}{i}{m}{e} -localts {l}{o}{c}{a}{l}{t}{i}{m}{e}{s}{t}{a}{m}{p} -now {n}{o}{w} - -bit {b}{i}{t} -tiny {t}{i}{n}{y} -tinyint {t}{i}{n}{y}{i}{n}{t} -smallint {s}{m}{a}{l}{l}{i}{n}{t} -mediumint {m}{e}{d}{i}{u}{m}{i}{n}{t} -int {i}{n}{t} -integer {i}{n}{t}{e}{g}{e}{r} -bigint {b}{i}{g}{i}{n}{t} -real {r}{e}{a}{l} -double {d}{o}{u}{b}{l}{e} -float {f}{l}{o}{a}{t} -decimal {d}{e}{c}{i}{m}{a}{l} -numeric {n}{u}{m}{e}{r}{i}{c} -date {d}{a}{t}{e} -time {t}{i}{m}{e} -timestamp {t}{i}{m}{e}{s}{t}{a}{m}{p} -datetime {d}{a}{t}{e}{t}{i}{m}{e} -year {y}{e}{a}{r} -char {c}{h}{a}{r} -varchar {v}{a}{r}{c}{h}{a}{r} -binary {b}{i}{n}{a}{r}{y} -varbinary {v}{a}{r}{b}{i}{n}{a}{r}{y} -tinyblob {t}{i}{n}{y}{b}{l}{o}{b} -blob {b}{l}{o}{b} -mediumblob {m}{e}{d}{i}{u}{m}{b}{l}{o}{b} -longblob {l}{o}{n}{g}{b}{l}{o}{b} -tinytext {t}{i}{n}{y}{t}{e}{x}{t} -text {t}{e}{x}{t} -mediumtext {m}{e}{d}{i}{u}{m}{t}{e}{x}{t} -longtext {l}{o}{n}{g}{t}{e}{x}{t} -enum {e}{n}{u}{m} -precision {p}{r}{e}{c}{i}{s}{i}{o}{n} - -signed {s}{i}{g}{n}{e}{d} -unsigned {u}{n}{s}{i}{g}{n}{e}{d} -zerofill {z}{e}{r}{o}{f}{i}{l}{l} - -bigrat {b}{i}{g}{r}{a}{t} -bool {b}{o}{o}{l} -boolean {b}{o}{o}{l}{e}{a}{n} -byte {b}{y}{t}{e} -duration {d}{u}{r}{a}{t}{i}{o}{n} -rune {r}{u}{n}{e} -string {s}{t}{r}{i}{n}{g} -use {u}{s}{e} -user {u}{s}{e}{r} -using {u}{s}{i}{n}{g} - -idchar0 [a-zA-Z_] -idchars {idchar0}|[0-9$] // See: https://dev.mysql.com/doc/refman/5.7/en/identifiers.html -ident {idchar0}{idchars}* - -user_var "@"{ident} -sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident} - -second_microsecond {s}{e}{c}{o}{n}{d}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -minute_microsecond {m}{i}{n}{u}{t}{e}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -minute_second {m}{i}{n}{u}{t}{e}_{s}{e}{c}{o}{n}{d} -hour_microsecond {h}{o}{u}{r}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -hour_second {h}{o}{u}{r}_{s}{e}{c}{o}{n}{d} -hour_minute {h}{o}{u}{r}_{m}{i}{n}{u}{t}{e} -day_microsecond {d}{a}{y}_{m}{i}{c}{r}{o}{s}{e}{c}{o}{n}{d} -day_second {d}{a}{y}_{s}{e}{c}{o}{n}{d} -day_minute {d}{a}{y}_{m}{i}{n}{u}{t}{e} -day_hour {d}{a}{y}_{h}{o}{u}{r} -year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} - -%yyc c -%yyn c = l.next() -%yyt l.sc - -%x S1 S2 S3 S4 - -%% - l.val = l.val[:0] - c0, l.line, l.col = l.c, l.nline, l.ncol - -<*>\0 return 0 - -[ \t\n\r]+ -#.* -\/\/.* -\/\*([^*]|\*+[^*/])*\*+\/ --- l.sc = S3 -[ \t]+.* {l.sc = 0} -[^ \t] { - l.sc = 0 - l.c = '-' - n := len(l.val) - l.unget(l.val[n-1]) - return '-' - } - -{int_lit} return l.int(lval) -{float_lit} return l.float(lval) -{hex_lit} return l.hex(lval) -{bit_lit} return l.bit(lval) - -\" l.sc = S1 -' l.sc = S2 -` l.sc = S4 - -[^\"\\]* l.stringLit = append(l.stringLit, l.val...) -\\. l.stringLit = append(l.stringLit, l.val...) -\"\" l.stringLit = append(l.stringLit, '"') -\" l.stringLit = append(l.stringLit, '"') - l.sc = 0 - return l.str(lval, "\"") -[^'\\]* l.stringLit = append(l.stringLit, l.val...) -\\. l.stringLit = append(l.stringLit, l.val...) -'' l.stringLit = append(l.stringLit, '\'') -' l.stringLit = append(l.stringLit, '\'') - l.sc = 0 - return l.str(lval, "'") -[^`]* l.stringLit = append(l.stringLit, l.val...) -`` l.stringLit = append(l.stringLit, '`') -` l.sc = 0 - lval.item = string(l.stringLit) - l.stringLit = l.stringLit[0:0] - return identifier - -"&&" return andand -"&^" return andnot -"<<" return lsh -"<=" return le -"=" return eq -">=" return ge -"!=" return neq -"<>" return neq -"||" return oror -">>" return rsh -"<=>" return nulleq - -"@" return at -"?" return placeholder - -{abs} lval.item = string(l.val) - return abs -{add} return add -{after} lval.item = string(l.val) - return after -{all} return all -{alter} return alter -{and} return and -{any} lval.item = string(l.val) - return any -{asc} return asc -{as} return as -{auto_increment} lval.item = string(l.val) - return autoIncrement -{avg} lval.item = string(l.val) - return avg -{avg_row_length} lval.item = string(l.val) - return avgRowLength -{begin} lval.item = string(l.val) - return begin -{between} return between -{both} return both -{by} return by -{case} return caseKwd -{cast} return cast -{character} return character -{charset} lval.item = string(l.val) - return charsetKwd -{check} return check -{checksum} lval.item = string(l.val) - return checksum -{coalesce} lval.item = string(l.val) - return coalesce -{collate} return collate -{collation} lval.item = string(l.val) - return collation -{column} return column -{columns} lval.item = string(l.val) - return columns -{comment} lval.item = string(l.val) - return comment -{commit} lval.item = string(l.val) - return commit -{compression} lval.item = string(l.val) - return compression -{concat} lval.item = string(l.val) - return concat -{concat_ws} lval.item = string(l.val) - return concatWs -{connection} lval.item = string(l.val) - return connection -{constraint} return constraint -{convert} return convert -{count} lval.item = string(l.val) - return count -{create} return create -{cross} return cross -{curdate} lval.item = string(l.val) - return curDate -{current_date} lval.item = string(l.val) - return currentDate -{current_user} lval.item = string(l.val) - return currentUser -{database} lval.item = string(l.val) - return database -{databases} return databases -{day} lval.item = string(l.val) - return day -{dayofweek} lval.item = string(l.val) - return dayofweek -{dayofmonth} lval.item = string(l.val) - return dayofmonth -{dayofyear} lval.item = string(l.val) - return dayofyear -{day_hour} lval.item = string(l.val) - return dayHour -{day_microsecond} lval.item = string(l.val) - return dayMicrosecond -{day_minute} lval.item = string(l.val) - return dayMinute -{day_second} lval.item = string(l.val) - return daySecond -{deallocate} lval.item = string(l.val) - return deallocate -{default} return defaultKwd -{delayed} return delayed -{delete} return deleteKwd -{desc} return desc -{describe} return describe -{drop} return drop -{distinct} return distinct -{div} return div -{do} lval.item = string(l.val) - return do -{dual} return dual -{duplicate} lval.item = string(l.val) - return duplicate -{else} return elseKwd -{end} lval.item = string(l.val) - return end -{engine} lval.item = string(l.val) - return engine -{engines} lval.item = string(l.val) - return engines -{execute} lval.item = string(l.val) - return execute -{enum} return enum -{escape} lval.item = string(l.val) - return escape -{exists} return exists -{explain} return explain -{extract} lval.item = string(l.val) - return extract -{first} lval.item = string(l.val) - return first -{for} return forKwd -{foreign} return foreign -{found_rows} lval.item = string(l.val) - return foundRows -{from} return from -{full} lval.item = string(l.val) - return full -{fulltext} return fulltext -{group} return group -{group_concat} lval.item = string(l.val) - return groupConcat -{having} return having -{high_priority} return highPriority -{hour} lval.item = string(l.val) - return hour -{hour_microsecond} lval.item = string(l.val) - return hourMicrosecond -{hour_minute} lval.item = string(l.val) - return hourMinute -{hour_second} lval.item = string(l.val) - return hourSecond -{identified} lval.item = string(l.val) - return identified -{if} lval.item = string(l.val) - return ifKwd -{ifnull} lval.item = string(l.val) - return ifNull -{ignore} return ignore -{index} return index -{inner} return inner -{insert} return insert -{into} return into -{in} return in -{is} return is -{join} return join -{key} return key -{key_block_size} lval.item = string(l.val) - return keyBlockSize -{leading} return leading -{left} lval.item = string(l.val) - return left -{length} lval.item = string(l.val) - return length -{like} return like -{limit} return limit -{local} lval.item = string(l.val) - return local -{locate} lval.item = string(l.val) - return locate -{lock} return lock -{lower} lval.item = string(l.val) - return lower -{low_priority} return lowPriority -{max} lval.item = string(l.val) - return max -{max_rows} lval.item = string(l.val) - return maxRows -{microsecond} lval.item = string(l.val) - return microsecond -{min} lval.item = string(l.val) - return min -{minute} lval.item = string(l.val) - return minute -{minute_microsecond} lval.item = string(l.val) - return minuteMicrosecond -{minute_second} lval.item = string(l.val) - return minuteSecond -{min_rows} lval.item = string(l.val) - return minRows -{mod} return mod -{mode} lval.item = string(l.val) - return mode -{month} lval.item = string(l.val) - return month -{names} lval.item = string(l.val) - return names -{national} lval.item = string(l.val) - return national -{not} return not -{offset} lval.item = string(l.val) - return offset -{on} return on -{order} return order -{or} return or -{outer} return outer -{password} lval.item = string(l.val) - return password -{prepare} lval.item = string(l.val) - return prepare -{primary} return primary -{quarter} lval.item = string(l.val) - return quarter -{quick} lval.item = string(l.val) - return quick -{right} return right -{rollback} lval.item = string(l.val) - return rollback -{row} lval.item = string(l.val) - return row -{schema} lval.item = string(l.val) - return schema -{schemas} return schemas -{session} lval.item = string(l.val) - return session -{some} lval.item = string(l.val) - return some -{start} lval.item = string(l.val) - return start -{global} lval.item = string(l.val) - return global -{rand} lval.item = string(l.val) - return rand -{repeat} lval.item = string(l.val) - return repeat -{regexp} return regexp -{references} return references -{rlike} return rlike - -{sys_var} lval.item = string(l.val) - return sysVar - -{user_var} lval.item = string(l.val) - return userVar -{second} lval.item = string(l.val) - return second -{second_microsecond} lval.item= string(l.val) - return secondMicrosecond -{select} return selectKwd - -{set} return set -{share} return share -{show} return show -{substring} lval.item = string(l.val) - return substring -{substring_index} lval.item = string(l.val) - return substringIndex -{sum} lval.item = string(l.val) - return sum -{sysdate} lval.item = string(l.val) - return sysDate -{table} return tableKwd -{tables} lval.item = string(l.val) - return tables -{then} return then -{trailing} return trailing -{transaction} lval.item = string(l.val) - return transaction -{trim} lval.item = string(l.val) - return trim -{truncate} lval.item = string(l.val) - return truncate -{union} return union -{unique} return unique -{unknown} lval.item = string(l.val) - return unknown -{nullif} lval.item = string(l.val) - return nullIf -{update} return update -{upper} lval.item = string(l.val) - return upper -{use} return use -{user} lval.item = string(l.val) - return user -{using} return using -{value} lval.item = string(l.val) - return value -{values} return values -{variables} lval.item = string(l.val) - return variables -{warnings} lval.item = string(l.val) - return warnings -{week} lval.item = string(l.val) - return week -{weekday} lval.item = string(l.val) - return weekday -{weekofyear} lval.item = string(l.val) - return weekofyear -{when} return when -{where} return where -{xor} return xor -{yearweek} lval.item = string(l.val) - return yearweek -{year_month} lval.item = string(l.val) - return yearMonth - -{signed} lval.item = string(l.val) - return signed -{unsigned} return unsigned -{zerofill} return zerofill - -{null} lval.item = nil - return null - -{false} return falseKwd - -{true} return trueKwd - -{calc_found_rows} lval.item = string(l.val) - return calcFoundRows - -{current_ts} lval.item = string(l.val) - return currentTs -{localtime} return localTime -{localts} return localTs -{now} lval.item = string(l.val) - return now - -{bit} lval.item = string(l.val) - return bitType - -{tiny} lval.item = string(l.val) - return tinyIntType - -{tinyint} lval.item = string(l.val) - return tinyIntType - -{smallint} lval.item = string(l.val) - return smallIntType - -{mediumint} lval.item = string(l.val) - return mediumIntType - -{bigint} lval.item = string(l.val) - return bigIntType - -{decimal} lval.item = string(l.val) - return decimalType - -{numeric} lval.item = string(l.val) - return numericType - -{float} lval.item = string(l.val) - return floatType - -{double} lval.item = string(l.val) - return doubleType - -{precision} lval.item = string(l.val) - return precisionType - -{real} lval.item = string(l.val) - return realType - -{date} lval.item = string(l.val) - return dateType - -{time} lval.item = string(l.val) - return timeType - -{timestamp} lval.item = string(l.val) - return timestampType - -{datetime} lval.item = string(l.val) - return datetimeType - -{year} lval.item = string(l.val) - return yearType - -{char} lval.item = string(l.val) - return charType - -{varchar} lval.item = string(l.val) - return varcharType - -{binary} lval.item = string(l.val) - return binaryType - -{varbinary} lval.item = string(l.val) - return varbinaryType - -{tinyblob} lval.item = string(l.val) - return tinyblobType - -{blob} lval.item = string(l.val) - return blobType - -{mediumblob} lval.item = string(l.val) - return mediumblobType - -{longblob} lval.item = string(l.val) - return longblobType - -{tinytext} lval.item = string(l.val) - return tinytextType - -{mediumtext} lval.item = string(l.val) - return mediumtextType - -{text} lval.item = string(l.val) - return textType - -{longtext} lval.item = string(l.val) - return longtextType - -{bool} lval.item = string(l.val) - return boolType - -{boolean} lval.item = string(l.val) - return booleanType - -{byte} lval.item = string(l.val) - return byteType - -{int} lval.item = string(l.val) - return intType - -{integer} lval.item = string(l.val) - return integerType - -{ident} lval.item = string(l.val) - return identifier - -. return c0 - -%% - return int(unicode.ReplacementChar) -} - -func (l *lexer) npos() (line, col int) { - if line, col = l.nline, l.ncol; col == 0 { - line-- - col = l.lcol+1 - } - return -} - -func (l *lexer) str(lval *yySymType, pref string) int { - l.sc = 0 - // TODO: performance issue. - s := string(l.stringLit) - l.stringLit = l.stringLit[0:0] - if pref == "'" { - s = strings.Replace(s, "\\'", "'", -1) - s = strings.TrimSuffix(s, "'") + "\"" - pref = "\"" - } - v, err := strconv.Unquote(pref + s) - if err != nil { - v = strings.TrimSuffix(s, pref) - } - lval.item = v - return stringLit -} - -func (l *lexer) trimIdent(idt string) string { - idt = strings.TrimPrefix(idt, "`") - idt = strings.TrimSuffix(idt, "`") - return idt -} - -func (l *lexer) int(lval *yySymType) int { - n, err := strconv.ParseUint(string(l.val), 0, 64) - if err != nil { - l.errf("integer literal: %v", err) - return int(unicode.ReplacementChar) - } - - switch { - case n < math.MaxInt64: - lval.item = int64(n) - default: - lval.item = uint64(n) - } - return intLit -} - -func (l *lexer) float(lval *yySymType) int { - n, err := strconv.ParseFloat(string(l.val), 64) - if err != nil { - l.errf("float literal: %v", err) - return int(unicode.ReplacementChar) - } - - lval.item = float64(n) - return floatLit -} - -// https://dev.mysql.com/doc/refman/5.7/en/hexadecimal-literals.html -func (l *lexer) hex(lval *yySymType) int { - s := string(l.val) - h, err := mysql.ParseHex(s) - if err != nil { - l.errf("hexadecimal literal: %v", err) - return int(unicode.ReplacementChar) - } - lval.item = h - return hexLit -} - -// https://dev.mysql.com/doc/refman/5.7/en/bit-type.html -func (l *lexer) bit(lval *yySymType) int { - s := string(l.val) - b, err := mysql.ParseBit(s, -1) - if err != nil { - l.errf("bit literal: %v", err) - return int(unicode.ReplacementChar) - } - lval.item = b - return bitLit -} diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 540f7bdd2a..9e1df6d58c 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/optimizer" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" @@ -197,5 +198,7 @@ func statement(sql string) stmt.Statement { log.Debug("Compile", sql) lexer := parser.NewLexer(sql) parser.YYParse(lexer) - return lexer.Stmts()[0].(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, _ := compiler.Compile(lexer.Stmts()[0]) + return stm } diff --git a/make.cmd b/make.cmd index 523ac82f76..83a2dbda67 100644 --- a/make.cmd +++ b/make.cmd @@ -18,17 +18,6 @@ DEL /F /A /Q y.output golex -o parser/scanner.go parser/scanner.l -@echo [Ast-Parser] -go get github.com/qiuyesuifeng/goyacc -go get github.com/qiuyesuifeng/golex -type nul >>temp.XXXXXX -goyacc -o nul -xegen "temp.XXXXXX" parser/parser.y -goyacc -o ast/parser/parser.go -xe "temp.XXXXXX" ast/parser/parser.y -DEL /F /A /Q temp.XXXXXX -DEL /F /A /Q y.output - -golex -o ast/parser/scanner.go ast/parser/scanner.l - @echo [Build] godep go build -ldflags '%LDFLAGS%' diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go new file mode 100644 index 0000000000..452d047c47 --- /dev/null +++ b/optimizer/aggregator.go @@ -0,0 +1,26 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +// Aggregator is the interface to +// compute aggregate function result. +type Aggregator interface { + // Input adds an input value to aggregator. + // The input values are accumulated in the aggregator. + Input(in ...interface{}) error + // Output uses input values to compute the aggregated result. + Output() interface{} + // Clear clears the input values. + Clear() +} diff --git a/optimizer/compiler.go b/optimizer/compiler.go new file mode 100644 index 0000000000..a9e42cd8b0 --- /dev/null +++ b/optimizer/compiler.go @@ -0,0 +1,132 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "sort" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/stmt" +) + +// Compiler compiles ast.Node into an executable statement. +type Compiler struct { + converter *expressionConverter +} + +// Compile compiles a ast.Node into an executable statement. +func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { + validator := &validator{} + if _, ok := node.Accept(validator); !ok { + return nil, errors.Trace(validator.err) + } + + // binder := &InfoBinder{} + // if _, ok := node.Accept(validator); !ok { + // return nil, errors.Trace(binder.Err) + // } + + tpComputer := &typeComputer{} + if _, ok := node.Accept(tpComputer); !ok { + return nil, errors.Trace(tpComputer.err) + } + c := newExpressionConverter() + com.converter = c + switch v := node.(type) { + case *ast.InsertStmt: + return convertInsert(c, v) + case *ast.DeleteStmt: + return convertDelete(c, v) + case *ast.UpdateStmt: + return convertUpdate(c, v) + case *ast.SelectStmt: + return convertSelect(c, v) + case *ast.UnionStmt: + return convertUnion(c, v) + case *ast.CreateDatabaseStmt: + return convertCreateDatabase(c, v) + case *ast.DropDatabaseStmt: + return convertDropDatabase(c, v) + case *ast.CreateTableStmt: + return convertCreateTable(c, v) + case *ast.DropTableStmt: + return convertDropTable(c, v) + case *ast.CreateIndexStmt: + return convertCreateIndex(c, v) + case *ast.DropIndexStmt: + return convertDropIndex(c, v) + case *ast.AlterTableStmt: + return convertAlterTable(c, v) + case *ast.TruncateTableStmt: + return convertTruncateTable(c, v) + case *ast.ExplainStmt: + return convertExplain(c, v) + case *ast.PrepareStmt: + return convertPrepare(c, v) + case *ast.DeallocateStmt: + return convertDeallocate(c, v) + case *ast.ExecuteStmt: + return convertExecute(c, v) + case *ast.ShowStmt: + return convertShow(c, v) + case *ast.BeginStmt: + return convertBegin(c, v) + case *ast.CommitStmt: + return convertCommit(c, v) + case *ast.RollbackStmt: + return convertRollback(c, v) + case *ast.UseStmt: + return convertUse(c, v) + case *ast.SetStmt: + return convertSet(c, v) + case *ast.SetCharsetStmt: + return convertSetCharset(c, v) + case *ast.SetPwdStmt: + return convertSetPwd(c, v) + case *ast.CreateUserStmt: + return convertCreateUser(c, v) + case *ast.DoStmt: + return convertDo(c, v) + case *ast.GrantStmt: + return convertGrant(c, v) + } + return nil, nil +} + +type paramMarkers []*ast.ParamMarkerExpr + +func (p paramMarkers) Len() int { + return len(p) +} + +func (p paramMarkers) Less(i, j int) bool { + return p[i].Offset < p[j].Offset +} + +func (p paramMarkers) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +// ParamMarkers returns parameter markers for prepared statement. +func (com *Compiler) ParamMarkers() []*expression.ParamMarker { + c := com.converter + sort.Sort(c.paramMarkers) + oldMarkers := make([]*expression.ParamMarker, len(c.paramMarkers)) + for i, val := range c.paramMarkers { + oldMarkers[i] = c.exprMap[val].(*expression.ParamMarker) + } + return oldMarkers +} diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go new file mode 100644 index 0000000000..a136819ae1 --- /dev/null +++ b/optimizer/convert_expr.go @@ -0,0 +1,425 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/subquery" + "github.com/pingcap/tidb/model" + "strings" +) + +func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression.Expression, error) { + expr.Accept(converter) + if converter.err != nil { + return nil, errors.Trace(converter.err) + } + return converter.exprMap[expr], nil +} + +// expressionConverter converts ast expression to +// old expression for transition state. +type expressionConverter struct { + exprMap map[ast.Node]expression.Expression + paramMarkers paramMarkers + err error +} + +func newExpressionConverter() *expressionConverter { + return &expressionConverter{ + exprMap: map[ast.Node]expression.Expression{}, + } +} + +// Enter implements ast.Visitor interface. +func (c *expressionConverter) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +// Leave implements ast.Visitor interface. +func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + c.value(v) + case *ast.BetweenExpr: + c.between(v) + case *ast.BinaryOperationExpr: + c.binaryOperation(v) + case *ast.WhenClause: + c.whenClause(v) + case *ast.CaseExpr: + c.caseExpr(v) + case *ast.SubqueryExpr: + c.subquery(v) + case *ast.CompareSubqueryExpr: + c.compareSubquery(v) + case *ast.ColumnNameExpr: + c.columnNameExpr(v) + case *ast.DefaultExpr: + c.defaultExpr(v) + case *ast.IdentifierExpr: + c.identifier(v) + case *ast.ExistsSubqueryExpr: + c.existsSubquery(v) + case *ast.PatternInExpr: + c.patternIn(v) + case *ast.IsNullExpr: + c.isNull(v) + case *ast.IsTruthExpr: + c.isTruth(v) + case *ast.PatternLikeExpr: + c.patternLike(v) + case *ast.ParamMarkerExpr: + c.paramMarker(v) + case *ast.ParenthesesExpr: + c.parentheses(v) + case *ast.PositionExpr: + c.position(v) + case *ast.PatternRegexpExpr: + c.patternRegexp(v) + case *ast.RowExpr: + c.row(v) + case *ast.UnaryOperationExpr: + c.unaryOperation(v) + case *ast.ValuesExpr: + c.values(v) + case *ast.VariableExpr: + c.variable(v) + case *ast.FuncCallExpr: + c.funcCall(v) + case *ast.FuncExtractExpr: + c.funcExtract(v) + case *ast.FuncConvertExpr: + c.funcConvert(v) + case *ast.FuncCastExpr: + c.funcCast(v) + case *ast.FuncSubstringExpr: + c.funcSubstring(v) + case *ast.FuncLocateExpr: + c.funcLocate(v) + case *ast.FuncTrimExpr: + c.funcTrim(v) + case *ast.FuncDateArithExpr: + c.funcDateArith(v) + case *ast.AggregateFuncExpr: + c.aggregateFunc(v) + } + return in, c.err == nil +} + +func (c *expressionConverter) value(v *ast.ValueExpr) { + c.exprMap[v] = expression.Value{Val: v.GetValue()} +} + +func (c *expressionConverter) between(v *ast.BetweenExpr) { + oldExpr := c.exprMap[v.Expr] + oldLo := c.exprMap[v.Left] + oldHi := c.exprMap[v.Right] + oldBetween, err := expression.NewBetween(oldExpr, oldLo, oldHi, v.Not) + if err != nil { + c.err = err + return + } + c.exprMap[v] = oldBetween +} + +func (c *expressionConverter) binaryOperation(v *ast.BinaryOperationExpr) { + oldeLeft := c.exprMap[v.L] + oldRight := c.exprMap[v.R] + oldBinop := expression.NewBinaryOperation(v.Op, oldeLeft, oldRight) + c.exprMap[v] = oldBinop +} + +func (c *expressionConverter) whenClause(v *ast.WhenClause) { + oldExpr := c.exprMap[v.Expr] + oldResult := c.exprMap[v.Result] + oldWhenClause := &expression.WhenClause{Expr: oldExpr, Result: oldResult} + c.exprMap[v] = oldWhenClause +} + +func (c *expressionConverter) caseExpr(v *ast.CaseExpr) { + oldValue := c.exprMap[v.Value] + oldWhenClauses := make([]*expression.WhenClause, len(v.WhenClauses)) + for i, val := range v.WhenClauses { + oldWhenClauses[i] = c.exprMap[val].(*expression.WhenClause) + } + oldElse := c.exprMap[v.ElseClause] + oldCaseExpr := &expression.FunctionCase{ + Value: oldValue, + WhenClauses: oldWhenClauses, + ElseClause: oldElse, + } + c.exprMap[v] = oldCaseExpr +} + +func (c *expressionConverter) subquery(v *ast.SubqueryExpr) { + oldSubquery := &subquery.SubQuery{} + switch x := v.Query.(type) { + case *ast.SelectStmt: + oldSelect, err := convertSelect(c, x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(c, x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldUnion + } + c.exprMap[v] = oldSubquery +} + +func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) { + expr := c.exprMap[v.L] + subquery := c.exprMap[v.R] + oldCmpSubquery := expression.NewCompareSubQuery(v.Op, expr, subquery.(expression.SubQuery), v.All) + c.exprMap[v] = oldCmpSubquery +} + +func joinColumnName(columnName *ast.ColumnName) string { + var originStrs []string + if columnName.Schema.O != "" { + originStrs = append(originStrs, columnName.Schema.O) + } + if columnName.Table.O != "" { + originStrs = append(originStrs, columnName.Table.O) + } + originStrs = append(originStrs, columnName.Name.O) + return strings.Join(originStrs, ".") +} + +func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) { + ident := &expression.Ident{} + ident.CIStr = model.NewCIStr(joinColumnName(v.Name)) + c.exprMap[v] = ident +} + +func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) { + oldDefault := &expression.Default{} + if v.Name != nil { + oldDefault.Name = joinColumnName(v.Name) + } + c.exprMap[v] = oldDefault +} + +func (c *expressionConverter) identifier(v *ast.IdentifierExpr) { + oldIdent := &expression.Ident{} + oldIdent.CIStr = v.Name + c.exprMap[v] = oldIdent +} + +func (c *expressionConverter) existsSubquery(v *ast.ExistsSubqueryExpr) { + subquery := c.exprMap[v.Sel].(expression.SubQuery) + c.exprMap[v] = &expression.ExistsSubQuery{Sel: subquery} +} + +func (c *expressionConverter) patternIn(v *ast.PatternInExpr) { + oldPatternIn := &expression.PatternIn{Not: v.Not} + if v.Sel != nil { + oldPatternIn.Sel = c.exprMap[v.Sel].(expression.SubQuery) + } + oldPatternIn.Expr = c.exprMap[v.Expr] + if v.List != nil { + oldPatternIn.List = make([]expression.Expression, len(v.List)) + for i, v := range v.List { + oldPatternIn.List[i] = c.exprMap[v] + } + } + c.exprMap[v] = oldPatternIn +} + +func (c *expressionConverter) isNull(v *ast.IsNullExpr) { + oldIsNull := &expression.IsNull{Not: v.Not} + oldIsNull.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsNull +} + +func (c *expressionConverter) isTruth(v *ast.IsTruthExpr) { + oldIsTruth := &expression.IsTruth{Not: v.Not, True: v.True} + oldIsTruth.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsTruth +} + +func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { + oldPatternLike := &expression.PatternLike{ + Not: v.Not, + Escape: v.Escape, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternLike +} + +func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { + if c.exprMap[v] == nil { + c.exprMap[v] = &expression.ParamMarker{} + c.paramMarkers = append(c.paramMarkers, v) + } +} + +func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { + oldExpr := c.exprMap[v.Expr] + c.exprMap[v] = &expression.PExpr{Expr: oldExpr} +} + +func (c *expressionConverter) position(v *ast.PositionExpr) { + c.exprMap[v] = &expression.Position{N: v.N, Name: v.Name} +} + +func (c *expressionConverter) patternRegexp(v *ast.PatternRegexpExpr) { + oldPatternRegexp := &expression.PatternRegexp{ + Not: v.Not, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternRegexp +} + +func (c *expressionConverter) row(v *ast.RowExpr) { + oldRow := &expression.Row{} + oldRow.Values = make([]expression.Expression, len(v.Values)) + for i, val := range v.Values { + oldRow.Values[i] = c.exprMap[val] + } + c.exprMap[v] = oldRow +} + +func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) { + oldUnary := &expression.UnaryOperation{ + Op: v.Op, + V: c.exprMap[v.V], + } + c.exprMap[v] = oldUnary +} + +func (c *expressionConverter) values(v *ast.ValuesExpr) { + nameStr := joinColumnName(v.Column) + c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)} +} + +func (c *expressionConverter) variable(v *ast.VariableExpr) { + c.exprMap[v] = &expression.Variable{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + } +} + +func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) { + oldCall := &expression.Call{ + F: v.FnName, + } + oldCall.Args = make([]expression.Expression, len(v.Args)) + for i, val := range v.Args { + oldCall.Args[i] = c.exprMap[val] + } + c.exprMap[v] = oldCall +} + +func (c *expressionConverter) funcExtract(v *ast.FuncExtractExpr) { + oldExtract := &expression.Extract{Unit: v.Unit} + oldExtract.Date = c.exprMap[v.Date] + c.exprMap[v] = oldExtract +} + +func (c *expressionConverter) funcConvert(v *ast.FuncConvertExpr) { + c.exprMap[v] = &expression.FunctionConvert{ + Expr: c.exprMap[v.Expr], + Charset: v.Charset, + } +} + +func (c *expressionConverter) funcCast(v *ast.FuncCastExpr) { + oldCast := &expression.FunctionCast{ + Expr: c.exprMap[v.Expr], + Tp: v.Tp, + } + switch v.FunctionType { + case ast.CastBinaryOperator: + oldCast.FunctionType = expression.BinaryOperator + case ast.CastConvertFunction: + oldCast.FunctionType = expression.ConvertFunction + case ast.CastFunction: + oldCast.FunctionType = expression.CastFunction + } + c.exprMap[v] = oldCast +} + +func (c *expressionConverter) funcSubstring(v *ast.FuncSubstringExpr) { + oldSubstring := &expression.FunctionSubstring{ + Len: c.exprMap[v.Len], + Pos: c.exprMap[v.Pos], + StrExpr: c.exprMap[v.StrExpr], + } + c.exprMap[v] = oldSubstring +} + +func (c *expressionConverter) funcLocate(v *ast.FuncLocateExpr) { + oldLocate := &expression.FunctionLocate{ + Pos: c.exprMap[v.Pos], + Str: c.exprMap[v.Str], + SubStr: c.exprMap[v.SubStr], + } + c.exprMap[v] = oldLocate +} + +func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) { + oldTrim := &expression.FunctionTrim{ + Str: c.exprMap[v.Str], + RemStr: c.exprMap[v.RemStr], + } + switch v.Direction { + case ast.TrimBoth: + oldTrim.Direction = expression.TrimBoth + case ast.TrimBothDefault: + oldTrim.Direction = expression.TrimBothDefault + case ast.TrimLeading: + oldTrim.Direction = expression.TrimLeading + case ast.TrimTrailing: + oldTrim.Direction = expression.TrimTrailing + } + c.exprMap[v] = oldTrim +} + +func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) { + oldDateArith := &expression.DateArith{ + Unit: v.Unit, + Date: c.exprMap[v.Date], + Interval: c.exprMap[v.Interval], + } + switch v.Op { + case ast.DateAdd: + oldDateArith.Op = expression.DateAdd + case ast.DateSub: + oldDateArith.Op = expression.DateSub + } + c.exprMap[v] = oldDateArith +} + +func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) { + oldAggregate := &expression.Call{ + F: v.F, + Distinct: v.Distinct, + } + for _, val := range v.Args { + oldAggregate.Args = append(oldAggregate.Args, c.exprMap[val]) + } + c.exprMap[v] = oldAggregate +} diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go new file mode 100644 index 0000000000..df769237f1 --- /dev/null +++ b/optimizer/convert_stmt.go @@ -0,0 +1,1024 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/parser/coldef" + "github.com/pingcap/tidb/rset/rsets" + "github.com/pingcap/tidb/stmt" + "github.com/pingcap/tidb/stmt/stmts" + "github.com/pingcap/tidb/table" +) + +func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expression.Assignment, error) { + oldAssign := &expression.Assignment{ + ColName: joinColumnName(v.Column), + } + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldAssign.Expr = oldExpr + return oldAssign, nil +} + +func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (stmt.Statement, error) { + insertValues := stmts.InsertValues{} + insertValues.Priority = v.Priority + tableName := v.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) + insertValues.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} + for _, val := range v.Columns { + insertValues.ColNames = append(insertValues.ColNames, joinColumnName(val)) + } + for _, row := range v.Lists { + var oldRow []expression.Expression + for _, val := range row { + oldExpr, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldRow = append(oldRow, oldExpr) + } + insertValues.Lists = append(insertValues.Lists, oldRow) + } + for _, assign := range v.Setlist { + oldAssign, err := convertAssignment(converter, assign) + if err != nil { + return nil, errors.Trace(err) + } + insertValues.Setlist = append(insertValues.Setlist, oldAssign) + } + + if v.Select != nil { + var err error + switch x := v.Select.(type) { + case *ast.SelectStmt: + insertValues.Sel, err = convertSelect(converter, x) + case *ast.UnionStmt: + insertValues.Sel, err = convertUnion(converter, x) + } + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Replace { + return &stmts.ReplaceIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), + }, nil + } + oldInsert := &stmts.InsertIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), + } + for _, onDup := range v.OnDuplicate { + oldOnDup, err := convertAssignment(converter, onDup) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, oldOnDup) + } + return oldInsert, nil +} + +func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { + oldDelete := &stmts.DeleteStmt{ + BeforeFrom: v.BeforeFrom, + Ignore: v.Ignore, + LowPriority: v.LowPriority, + MultiTable: v.MultiTable, + Quick: v.Quick, + Text: v.Text(), + } + oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + oldDelete.Refs = oldRefs + for _, val := range v.Tables { + tableIdent := table.Ident{Schema: val.Schema, Name: val.Name} + oldDelete.TableIdents = append(oldDelete.TableIdents, tableIdent) + } + if v.Where != nil { + oldDelete.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Order != nil { + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) + } + oldDelete.Order = oldOrderBy + } + if v.Limit != nil { + oldDelete.Limit = &rsets.LimitRset{Count: v.Limit.Count} + } + return oldDelete, nil +} + +func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { + oldUpdate := &stmts.UpdateStmt{ + Ignore: v.Ignore, + MultipleTable: v.MultipleTable, + LowPriority: v.LowPriority, + Text: v.Text(), + } + var err error + oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + if v.Where != nil { + oldUpdate.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + for _, val := range v.List { + var oldAssign *expression.Assignment + oldAssign, err = convertAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldUpdate.List = append(oldUpdate.List, oldAssign) + } + if v.Order != nil { + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) + } + oldUpdate.Order = oldOrderBy + } + if v.Limit != nil { + oldUpdate.Limit = &rsets.LimitRset{Count: v.Limit.Count} + } + return oldUpdate, nil +} + +func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { + oldSelect := &stmts.SelectStmt{ + Distinct: s.Distinct, + Text: s.Text(), + } + oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) + for i, val := range s.Fields.Fields { + oldField := &field.Field{} + oldField.AsName = val.AsName.O + var err error + if val.Expr != nil { + oldField.Expr, err = convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + // TODO: handle parenthesesed column name expression, which should not set AsName. + if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" { + oldField.AsName = val.Text() + } + } else if val.WildCard != nil { + str := "*" + if val.WildCard.Table.O != "" { + str = val.WildCard.Table.O + ".*" + if val.WildCard.Schema.O != "" { + str = val.WildCard.Schema.O + "." + str + } + } + oldField.Expr = &expression.Ident{CIStr: model.NewCIStr(str)} + } + oldSelect.Fields[i] = oldField + } + var err error + if s.From != nil { + oldSelect.From, err = convertJoin(converter, s.From.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Where != nil { + oldSelect.Where = &rsets.WhereRset{} + oldSelect.Where.Expr, err = convertExpr(converter, s.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + + if s.GroupBy != nil { + oldSelect.GroupBy, err = convertGroupBy(converter, s.GroupBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Having != nil { + oldSelect.Having, err = convertHaving(converter, s.Having) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.OrderBy != nil { + oldSelect.OrderBy, err = convertOrderBy(converter, s.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Limit != nil { + if s.Limit.Offset > 0 { + oldSelect.Offset = &rsets.OffsetRset{Count: s.Limit.Offset} + } + if s.Limit.Count > 0 { + oldSelect.Limit = &rsets.LimitRset{Count: s.Limit.Count} + } + } + switch s.LockTp { + case ast.SelectLockForUpdate: + oldSelect.Lock = coldef.SelectLockForUpdate + case ast.SelectLockInShareMode: + oldSelect.Lock = coldef.SelectLockInShareMode + case ast.SelectLockNone: + oldSelect.Lock = coldef.SelectLockNone + } + return oldSelect, nil +} + +func convertUnion(converter *expressionConverter, u *ast.UnionStmt) (*stmts.UnionStmt, error) { + oldUnion := &stmts.UnionStmt{ + Text: u.Text(), + } + oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) + oldUnion.Distincts = make([]bool, len(u.Selects)-1) + if u.Distinct { + for i := range oldUnion.Distincts { + oldUnion.Distincts[i] = true + } + } + for i, val := range u.Selects { + oldSelect, err := convertSelect(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.Selects[i] = oldSelect + } + if u.OrderBy != nil { + oldOrderBy, err := convertOrderBy(converter, u.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.OrderBy = oldOrderBy + } + if u.Limit != nil { + if u.Limit.Offset > 0 { + oldUnion.Offset = &rsets.OffsetRset{Count: u.Limit.Offset} + } + if u.Limit.Count > 0 { + oldUnion.Limit = &rsets.LimitRset{Count: u.Limit.Count} + } + } + // Union order by can not converted to old because it is pushed to select statements. + return oldUnion, nil +} + +func convertJoin(converter *expressionConverter, join *ast.Join) (*rsets.JoinRset, error) { + oldJoin := &rsets.JoinRset{} + switch join.Tp { + case ast.CrossJoin: + oldJoin.Type = rsets.CrossJoin + case ast.LeftJoin: + oldJoin.Type = rsets.LeftJoin + case ast.RightJoin: + oldJoin.Type = rsets.RightJoin + } + switch l := join.Left.(type) { + case *ast.Join: + oldLeft, err := convertJoin(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + case *ast.TableSource: + oldLeft, err := convertTableSource(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + } + + switch r := join.Right.(type) { + case *ast.Join: + oldRight, err := convertJoin(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case *ast.TableSource: + oldRight, err := convertTableSource(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case nil: + } + + if join.On != nil { + oldOn, err := convertExpr(converter, join.On.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.On = oldOn + } + return oldJoin, nil +} + +func convertTableSource(converter *expressionConverter, ts *ast.TableSource) (*rsets.TableSource, error) { + oldTs := &rsets.TableSource{} + oldTs.Name = ts.AsName.O + switch src := ts.Source.(type) { + case *ast.TableName: + oldTs.Source = table.Ident{Schema: src.Schema, Name: src.Name} + case *ast.SelectStmt: + oldSelect, err := convertSelect(converter, src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(converter, src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldUnion + } + return oldTs, nil +} + +func convertGroupBy(converter *expressionConverter, gb *ast.GroupByClause) (*rsets.GroupByRset, error) { + oldGroupBy := &rsets.GroupByRset{ + By: make([]expression.Expression, len(gb.Items)), + } + for i, val := range gb.Items { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldGroupBy.By[i] = oldExpr + } + return oldGroupBy, nil +} + +func convertHaving(converter *expressionConverter, having *ast.HavingClause) (*rsets.HavingRset, error) { + oldHaving := &rsets.HavingRset{} + oldExpr, err := convertExpr(converter, having.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldHaving.Expr = oldExpr + return oldHaving, nil +} + +func convertOrderBy(converter *expressionConverter, orderBy *ast.OrderByClause) (*rsets.OrderByRset, error) { + oldOrderBy := &rsets.OrderByRset{} + oldOrderBy.By = make([]rsets.OrderByItem, len(orderBy.Items)) + for i, val := range orderBy.Items { + oldByItemExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldOrderBy.By[i].Expr = oldByItemExpr + oldOrderBy.By[i].Asc = !val.Desc + } + return oldOrderBy, nil +} + +func convertCreateDatabase(converter *expressionConverter, v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { + oldCreateDatabase := &stmts.CreateDatabaseStmt{ + IfNotExists: v.IfNotExists, + Name: v.Name, + Text: v.Text(), + } + if len(v.Options) != 0 { + oldCreateDatabase.Opt = &coldef.CharsetOpt{} + for _, val := range v.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + oldCreateDatabase.Opt.Chs = val.Value + case ast.DatabaseOptionCollate: + oldCreateDatabase.Opt.Col = val.Value + } + } + } + return oldCreateDatabase, nil +} + +func convertDropDatabase(converter *expressionConverter, v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { + return &stmts.DropDatabaseStmt{ + IfExists: v.IfExists, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertColumnOption(converter *expressionConverter, v *ast.ColumnOption) (*coldef.ConstraintOpt, error) { + oldColumnOpt := &coldef.ConstraintOpt{} + switch v.Tp { + case ast.ColumnOptionAutoIncrement: + oldColumnOpt.Tp = coldef.ConstrAutoIncrement + case ast.ColumnOptionComment: + oldColumnOpt.Tp = coldef.ConstrComment + case ast.ColumnOptionDefaultValue: + oldColumnOpt.Tp = coldef.ConstrDefaultValue + case ast.ColumnOptionIndex: + oldColumnOpt.Tp = coldef.ConstrIndex + case ast.ColumnOptionKey: + oldColumnOpt.Tp = coldef.ConstrKey + case ast.ColumnOptionFulltext: + oldColumnOpt.Tp = coldef.ConstrFulltext + case ast.ColumnOptionNotNull: + oldColumnOpt.Tp = coldef.ConstrNotNull + case ast.ColumnOptionNoOption: + oldColumnOpt.Tp = coldef.ConstrNoConstr + case ast.ColumnOptionOnUpdate: + oldColumnOpt.Tp = coldef.ConstrOnUpdate + case ast.ColumnOptionPrimaryKey: + oldColumnOpt.Tp = coldef.ConstrPrimaryKey + case ast.ColumnOptionNull: + oldColumnOpt.Tp = coldef.ConstrNull + case ast.ColumnOptionUniq: + oldColumnOpt.Tp = coldef.ConstrUniq + case ast.ColumnOptionUniqIndex: + oldColumnOpt.Tp = coldef.ConstrUniqIndex + case ast.ColumnOptionUniqKey: + oldColumnOpt.Tp = coldef.ConstrUniqKey + } + if v.Expr != nil { + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldColumnOpt.Evalue = oldExpr + } + return oldColumnOpt, nil +} + +func convertColumnDef(converter *expressionConverter, v *ast.ColumnDef) (*coldef.ColumnDef, error) { + oldColDef := &coldef.ColumnDef{ + Name: v.Name.Name.O, + Tp: v.Tp, + } + for _, val := range v.Options { + oldOpt, err := convertColumnOption(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldColDef.Constraints = append(oldColDef.Constraints, oldOpt) + } + return oldColDef, nil +} + +func convertIndexColNames(v []*ast.IndexColName) (out []*coldef.IndexColName) { + for _, val := range v { + oldIndexColKey := &coldef.IndexColName{ + ColumnName: val.Column.Name.O, + Length: val.Length, + } + out = append(out, oldIndexColKey) + } + return +} + +func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*coldef.TableConstraint, error) { + oldConstraint := &coldef.TableConstraint{ConstrName: v.Name} + switch v.Tp { + case ast.ConstraintNoConstraint: + oldConstraint.Tp = coldef.ConstrNoConstr + case ast.ConstraintPrimaryKey: + oldConstraint.Tp = coldef.ConstrPrimaryKey + case ast.ConstraintKey: + oldConstraint.Tp = coldef.ConstrKey + case ast.ConstraintIndex: + oldConstraint.Tp = coldef.ConstrIndex + case ast.ConstraintUniq: + oldConstraint.Tp = coldef.ConstrUniq + case ast.ConstraintUniqKey: + oldConstraint.Tp = coldef.ConstrUniqKey + case ast.ConstraintUniqIndex: + oldConstraint.Tp = coldef.ConstrUniqIndex + case ast.ConstraintForeignKey: + oldConstraint.Tp = coldef.ConstrForeignKey + case ast.ConstraintFulltext: + oldConstraint.Tp = coldef.ConstrFulltext + } + oldConstraint.Keys = convertIndexColNames(v.Keys) + if v.Refer != nil { + oldConstraint.Refer = &coldef.ReferenceDef{ + TableIdent: table.Ident{Schema: v.Refer.Table.Schema, Name: v.Refer.Table.Name}, + IndexColNames: convertIndexColNames(v.Refer.IndexColNames), + } + } + return oldConstraint, nil +} + +func convertCreateTable(converter *expressionConverter, v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { + oldCreateTable := &stmts.CreateTableStmt{ + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + IfNotExists: v.IfNotExists, + Text: v.Text(), + } + for _, val := range v.Cols { + oldColDef, err := convertColumnDef(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Cols = append(oldCreateTable.Cols, oldColDef) + } + for _, val := range v.Constraints { + oldConstr, err := convertConstraint(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Constraints = append(oldCreateTable.Constraints, oldConstr) + } + if len(v.Options) != 0 { + oldTableOpt := &coldef.TableOption{} + for _, val := range v.Options { + switch val.Tp { + case ast.TableOptionEngine: + oldTableOpt.Engine = val.StrValue + case ast.TableOptionCharset: + oldTableOpt.Charset = val.StrValue + case ast.TableOptionCollate: + oldTableOpt.Collate = val.StrValue + case ast.TableOptionAutoIncrement: + oldTableOpt.AutoIncrement = val.UintValue + } + } + oldCreateTable.Opt = oldTableOpt + } + return oldCreateTable, nil +} + +func convertDropTable(converter *expressionConverter, v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { + oldDropTable := &stmts.DropTableStmt{ + IfExists: v.IfExists, + Text: v.Text(), + } + oldDropTable.TableIdents = make([]table.Ident, len(v.Tables)) + for i, val := range v.Tables { + oldDropTable.TableIdents[i] = table.Ident{ + Schema: val.Schema, + Name: val.Name, + } + } + return oldDropTable, nil +} + +func convertCreateIndex(converter *expressionConverter, v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { + oldCreateIndex := &stmts.CreateIndexStmt{ + IndexName: v.IndexName, + Unique: v.Unique, + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + } + oldCreateIndex.IndexColNames = make([]*coldef.IndexColName, len(v.IndexColNames)) + for i, val := range v.IndexColNames { + oldIndexColName := &coldef.IndexColName{ + ColumnName: joinColumnName(val.Column), + Length: val.Length, + } + oldCreateIndex.IndexColNames[i] = oldIndexColName + } + return oldCreateIndex, nil +} + +func convertDropIndex(converter *expressionConverter, v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { + return &stmts.DropIndexStmt{ + IfExists: v.IfExists, + IndexName: v.IndexName, + Text: v.Text(), + }, nil +} + +func convertAlterTableSpec(converter *expressionConverter, v *ast.AlterTableSpec) (*ddl.AlterSpecification, error) { + oldAlterSpec := &ddl.AlterSpecification{ + Name: v.Name, + } + switch v.Tp { + case ast.AlterTableAddConstraint: + oldAlterSpec.Action = ddl.AlterAddConstr + case ast.AlterTableAddColumn: + oldAlterSpec.Action = ddl.AlterAddColumn + case ast.AlterTableDropColumn: + oldAlterSpec.Action = ddl.AlterDropColumn + case ast.AlterTableDropForeignKey: + oldAlterSpec.Action = ddl.AlterDropForeignKey + case ast.AlterTableDropIndex: + oldAlterSpec.Action = ddl.AlterDropIndex + case ast.AlterTableDropPrimaryKey: + oldAlterSpec.Action = ddl.AlterDropPrimaryKey + case ast.AlterTableOption: + oldAlterSpec.Action = ddl.AlterTableOpt + } + if v.Column != nil { + oldColDef, err := convertColumnDef(converter, v.Column) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Column = oldColDef + } + if v.Position != nil { + oldAlterSpec.Position = &ddl.ColumnPosition{} + switch v.Position.Tp { + case ast.ColumnPositionNone: + oldAlterSpec.Position.Type = ddl.ColumnPositionNone + case ast.ColumnPositionFirst: + oldAlterSpec.Position.Type = ddl.ColumnPositionFirst + case ast.ColumnPositionAfter: + oldAlterSpec.Position.Type = ddl.ColumnPositionAfter + } + if v.Position.RelativeColumn != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.Position.RelativeColumn) + } + } + if v.DropColumn != nil { + oldAlterSpec.Name = joinColumnName(v.DropColumn) + } + if v.Constraint != nil { + oldConstraint, err := convertConstraint(converter, v.Constraint) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Constraint = oldConstraint + } + for _, val := range v.Options { + oldOpt := &coldef.TableOpt{ + StrValue: val.StrValue, + UintValue: val.UintValue, + } + switch val.Tp { + case ast.TableOptionNone: + oldOpt.Tp = coldef.TblOptNone + case ast.TableOptionEngine: + oldOpt.Tp = coldef.TblOptEngine + case ast.TableOptionCharset: + oldOpt.Tp = coldef.TblOptCharset + case ast.TableOptionCollate: + oldOpt.Tp = coldef.TblOptCollate + case ast.TableOptionAutoIncrement: + oldOpt.Tp = coldef.TblOptAutoIncrement + case ast.TableOptionComment: + oldOpt.Tp = coldef.TblOptComment + case ast.TableOptionAvgRowLength: + oldOpt.Tp = coldef.TblOptAvgRowLength + case ast.TableOptionCheckSum: + oldOpt.Tp = coldef.TblOptCheckSum + case ast.TableOptionCompression: + oldOpt.Tp = coldef.TblOptCompression + case ast.TableOptionConnection: + oldOpt.Tp = coldef.TblOptConnection + case ast.TableOptionPassword: + oldOpt.Tp = coldef.TblOptPassword + case ast.TableOptionKeyBlockSize: + oldOpt.Tp = coldef.TblOptKeyBlockSize + case ast.TableOptionMaxRows: + oldOpt.Tp = coldef.TblOptMaxRows + case ast.TableOptionMinRows: + oldOpt.Tp = coldef.TblOptMinRows + } + oldAlterSpec.TableOpts = append(oldAlterSpec.TableOpts, oldOpt) + } + return oldAlterSpec, nil +} + +func convertAlterTable(converter *expressionConverter, v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { + oldAlterTable := &stmts.AlterTableStmt{ + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertAlterTableSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterTable.Specs = append(oldAlterTable.Specs, oldSpec) + } + return oldAlterTable, nil +} + +func convertTruncateTable(converter *expressionConverter, v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { + return &stmts.TruncateTableStmt{ + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + }, nil +} + +func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { + oldExplain := &stmts.ExplainStmt{ + Text: v.Text(), + } + var err error + switch x := v.Stmt.(type) { + case *ast.SelectStmt: + oldExplain.S, err = convertSelect(converter, x) + case *ast.UpdateStmt: + oldExplain.S, err = convertUpdate(converter, x) + case *ast.DeleteStmt: + oldExplain.S, err = convertDelete(converter, x) + case *ast.InsertStmt: + oldExplain.S, err = convertInsert(converter, x) + case *ast.ShowStmt: + oldExplain.S, err = convertShow(converter, x) + } + if err != nil { + return nil, errors.Trace(err) + } + return oldExplain, nil +} + +func convertPrepare(converter *expressionConverter, v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { + oldPrepare := &stmts.PreparedStmt{ + InPrepare: true, + Name: v.Name, + SQLText: v.SQLText, + Text: v.Text(), + } + if v.SQLVar != nil { + oldSQLVar, err := convertExpr(converter, v.SQLVar) + if err != nil { + return nil, errors.Trace(err) + } + oldPrepare.SQLVar = oldSQLVar.(*expression.Variable) + } + return oldPrepare, nil +} + +func convertDeallocate(converter *expressionConverter, v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { + return &stmts.DeallocateStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertExecute(converter *expressionConverter, v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { + oldExec := &stmts.ExecuteStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + } + oldExec.UsingVars = make([]expression.Expression, len(v.UsingVars)) + for i, val := range v.UsingVars { + oldVar, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldExec.UsingVars[i] = oldVar + } + return oldExec, nil +} + +func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowStmt, error) { + oldShow := &stmts.ShowStmt{ + DBName: v.DBName, + Flag: v.Flag, + Full: v.Full, + GlobalScope: v.GlobalScope, + Text: v.Text(), + } + if v.Table != nil { + oldShow.TableIdent = table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + } + } + if v.Column != nil { + oldShow.ColumnName = joinColumnName(v.Column) + } + if v.Where != nil { + oldWhere, err := convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Where = oldWhere + } + if v.Pattern != nil { + oldPattern, err := convertExpr(converter, v.Pattern) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Pattern = oldPattern.(*expression.PatternLike) + } + switch v.Tp { + case ast.ShowCharset: + oldShow.Target = stmt.ShowCharset + case ast.ShowCollation: + oldShow.Target = stmt.ShowCollation + case ast.ShowColumns: + oldShow.Target = stmt.ShowColumns + case ast.ShowCreateTable: + oldShow.Target = stmt.ShowCreateTable + case ast.ShowDatabases: + oldShow.Target = stmt.ShowDatabases + case ast.ShowTables: + oldShow.Target = stmt.ShowTables + case ast.ShowEngines: + oldShow.Target = stmt.ShowEngines + case ast.ShowVariables: + oldShow.Target = stmt.ShowVariables + case ast.ShowWarnings: + oldShow.Target = stmt.ShowWarnings + case ast.ShowNone: + oldShow.Target = stmt.ShowNone + } + return oldShow, nil +} + +func convertBegin(converter *expressionConverter, v *ast.BeginStmt) (*stmts.BeginStmt, error) { + return &stmts.BeginStmt{ + Text: v.Text(), + }, nil +} + +func convertCommit(converter *expressionConverter, v *ast.CommitStmt) (*stmts.CommitStmt, error) { + return &stmts.CommitStmt{ + Text: v.Text(), + }, nil +} + +func convertRollback(converter *expressionConverter, v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { + return &stmts.RollbackStmt{ + Text: v.Text(), + }, nil +} + +func convertUse(converter *expressionConverter, v *ast.UseStmt) (*stmts.UseStmt, error) { + return &stmts.UseStmt{ + DBName: v.DBName, + Text: v.Text(), + }, nil +} + +func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAssignment) (*stmts.VariableAssignment, error) { + oldValue, err := convertExpr(converter, v.Value) + if err != nil { + return nil, errors.Trace(err) + } + + return &stmts.VariableAssignment{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + Value: oldValue, + Text: v.Text(), + }, nil +} + +func convertSet(converter *expressionConverter, v *ast.SetStmt) (*stmts.SetStmt, error) { + oldSet := &stmts.SetStmt{ + Text: v.Text(), + Variables: make([]*stmts.VariableAssignment, len(v.Variables)), + } + for i, val := range v.Variables { + oldAssign, err := convertVariableAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldSet.Variables[i] = oldAssign + } + return oldSet, nil +} + +func convertSetCharset(converter *expressionConverter, v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { + return &stmts.SetCharsetStmt{ + Charset: v.Charset, + Collate: v.Collate, + Text: v.Text(), + }, nil +} + +func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { + return &stmts.SetPwdStmt{ + User: v.User, + Password: v.Password, + Text: v.Text(), + }, nil +} + +func convertUserSpec(converter *expressionConverter, v *ast.UserSpec) (*coldef.UserSpecification, error) { + oldSpec := &coldef.UserSpecification{ + User: v.User, + } + if v.AuthOpt != nil { + oldAuthOpt := &coldef.AuthOption{ + AuthString: v.AuthOpt.AuthString, + ByAuthString: v.AuthOpt.ByAuthString, + HashString: v.AuthOpt.HashString, + } + oldSpec.AuthOpt = oldAuthOpt + } + return oldSpec, nil +} + +func convertCreateUser(converter *expressionConverter, v *ast.CreateUserStmt) (*stmts.CreateUserStmt, error) { + oldCreateUser := &stmts.CreateUserStmt{ + IfNotExists: v.IfNotExists, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateUser.Specs = append(oldCreateUser.Specs, oldSpec) + } + return oldCreateUser, nil +} + +func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, error) { + oldDo := &stmts.DoStmt{ + Text: v.Text(), + Exprs: make([]expression.Expression, len(v.Exprs)), + } + for i, val := range v.Exprs { + oldExpr, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldDo.Exprs[i] = oldExpr + } + return oldDo, nil +} + +func convertPrivElem(converter *expressionConverter, v *ast.PrivElem) (*coldef.PrivElem, error) { + oldPrim := &coldef.PrivElem{ + Priv: v.Priv, + } + for _, val := range v.Cols { + oldPrim.Cols = append(oldPrim.Cols, joinColumnName(val)) + } + return oldPrim, nil +} + +func convertGrant(converter *expressionConverter, v *ast.GrantStmt) (*stmts.GrantStmt, error) { + oldGrant := &stmts.GrantStmt{ + Text: v.Text(), + } + for _, val := range v.Privs { + oldPrim, err := convertPrivElem(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Privs = append(oldGrant.Privs, oldPrim) + } + switch v.ObjectType { + case ast.ObjectTypeNone: + oldGrant.ObjectType = coldef.ObjectTypeNone + case ast.ObjectTypeTable: + oldGrant.ObjectType = coldef.ObjectTypeTable + } + if v.Level != nil { + oldGrantLevel := &coldef.GrantLevel{ + DBName: v.Level.DBName, + TableName: v.Level.TableName, + } + switch v.Level.Level { + case ast.GrantLevelDB: + oldGrantLevel.Level = coldef.GrantLevelDB + case ast.GrantLevelGlobal: + oldGrantLevel.Level = coldef.GrantLevelGlobal + case ast.GrantLevelNone: + oldGrantLevel.Level = coldef.GrantLevelNone + case ast.GrantLevelTable: + oldGrantLevel.Level = coldef.GrantLevelTable + } + oldGrant.Level = oldGrantLevel + } + for _, val := range v.Users { + oldUserSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Users = append(oldGrant.Users, oldUserSpec) + } + return oldGrant, nil +} diff --git a/optimizer/evaluator.go b/optimizer/evaluator.go new file mode 100644 index 0000000000..fffb5b53b3 --- /dev/null +++ b/optimizer/evaluator.go @@ -0,0 +1,317 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/plan" +) + +// Evaluator is a ast visitor that evaluates an expression. +type Evaluator struct { + // columnMap is the map from ColumnName to the position of the rowStack. + // It is used to find the value of the column. + columnMap map[*ast.ColumnName]position + // rowStack is the current row values while scanning. + // It should be updated after scaned a new row. + rowStack []*plan.Row + + // the map from AggregateFuncExpr to aggregator index. + aggregateMap map[*ast.AggregateFuncExpr]int + + // aggregators for the current row + // only outer query aggregate functions are handled. + aggregators []Aggregator + + // when aggregation phase is done, the input is + aggregateDone bool + + err error +} + +type position struct { + stackOffset int + fieldList bool + columnOffset int +} + +// Enter implements ast.Visitor interface. +func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +// Leave implements ast.Visitor interface. +func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + ok = true + case *ast.BetweenExpr: + ok = e.between(v) + case *ast.BinaryOperationExpr: + ok = e.binaryOperation(v) + case *ast.WhenClause: + ok = e.whenClause(v) + case *ast.CaseExpr: + ok = e.caseExpr(v) + case *ast.SubqueryExpr: + ok = e.subquery(v) + case *ast.CompareSubqueryExpr: + ok = e.compareSubquery(v) + case *ast.ColumnNameExpr: + ok = e.columnName(v) + case *ast.DefaultExpr: + ok = e.defaultExpr(v) + case *ast.IdentifierExpr: + ok = e.identifier(v) + case *ast.ExistsSubqueryExpr: + ok = e.existsSubquery(v) + case *ast.PatternInExpr: + ok = e.patternIn(v) + case *ast.IsNullExpr: + ok = e.isNull(v) + case *ast.IsTruthExpr: + ok = e.isTruth(v) + case *ast.PatternLikeExpr: + ok = e.patternLike(v) + case *ast.ParamMarkerExpr: + ok = e.paramMarker(v) + case *ast.ParenthesesExpr: + ok = e.parentheses(v) + case *ast.PositionExpr: + ok = e.position(v) + case *ast.PatternRegexpExpr: + ok = e.patternRegexp(v) + case *ast.RowExpr: + ok = e.row(v) + case *ast.UnaryOperationExpr: + ok = e.unaryOperation(v) + case *ast.ValuesExpr: + ok = e.values(v) + case *ast.VariableExpr: + ok = e.variable(v) + case *ast.FuncCallExpr: + ok = e.funcCall(v) + case *ast.FuncExtractExpr: + ok = e.funcExtract(v) + case *ast.FuncConvertExpr: + ok = e.funcConvert(v) + case *ast.FuncCastExpr: + ok = e.funcCast(v) + case *ast.FuncSubstringExpr: + ok = e.funcSubstring(v) + case *ast.FuncLocateExpr: + ok = e.funcLocate(v) + case *ast.FuncTrimExpr: + ok = e.funcTrim(v) + case *ast.AggregateFuncExpr: + ok = e.aggregateFunc(v) + } + out = in + return +} + +func checkAllOneColumn(exprs ...ast.ExprNode) bool { + for _, expr := range exprs { + switch v := expr.(type) { + case *ast.RowExpr: + return false + case *ast.SubqueryExpr: + if len(v.Query.GetResultFields()) != 1 { + return false + } + } + } + return true +} + +func (e *Evaluator) between(v *ast.BetweenExpr) bool { + if !checkAllOneColumn(v.Expr, v.Left, v.Right) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + + var l, r ast.ExprNode + op := opcode.AndAnd + + if v.Not { + // v < lv || v > rv + op = opcode.OrOr + l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right} + } else { + // v >= lv && v <= rv + l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right} + } + + ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r} + ret.Accept(e) + return e.err == nil +} + +func columnCount(e ast.ExprNode) (int, error) { + switch x := e.(type) { + case *ast.RowExpr: + n := len(x.Values) + if n <= 1 { + return 0, errors.Errorf("Operand should contain >= 2 columns for Row") + } + return n, nil + case *ast.SubqueryExpr: + return len(x.Query.GetResultFields()), nil + default: + return 1, nil + } +} + +func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error { + l, err := columnCount(e) + if err != nil { + return errors.Trace(err) + } + var n int + for _, arg := range args { + n, err = columnCount(arg) + if err != nil { + return errors.Trace(err) + } + + if n != l { + return errors.Errorf("Operand should contain %d column(s)", l) + } + } + return nil +} + +func (e *Evaluator) whenClause(v *ast.WhenClause) bool { + return true +} + +func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { + return true +} + +func (e *Evaluator) subquery(v *ast.SubqueryExpr) bool { + return true +} + +func (e *Evaluator) compareSubquery(v *ast.CompareSubqueryExpr) bool { + return true +} + +func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool { + return true +} + +func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool { + return true +} + +func (e *Evaluator) identifier(v *ast.IdentifierExpr) bool { + return true +} + +func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { + return true +} + +func (e *Evaluator) patternIn(v *ast.PatternInExpr) bool { + return true +} + +func (e *Evaluator) isNull(v *ast.IsNullExpr) bool { + return true +} + +func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { + return true +} + +func (e *Evaluator) patternLike(v *ast.PatternLikeExpr) bool { + return true +} + +func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool { + return true +} + +func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool { + return true +} + +func (e *Evaluator) position(v *ast.PositionExpr) bool { + return true +} + +func (e *Evaluator) patternRegexp(v *ast.PatternRegexpExpr) bool { + return true +} + +func (e *Evaluator) row(v *ast.RowExpr) bool { + return true +} + +func (e *Evaluator) unaryOperation(v *ast.UnaryOperationExpr) bool { + return true +} + +func (e *Evaluator) values(v *ast.ValuesExpr) bool { + return true +} + +func (e *Evaluator) variable(v *ast.VariableExpr) bool { + return true +} + +func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool { + return true +} + +func (e *Evaluator) funcExtract(v *ast.FuncExtractExpr) bool { + return true +} + +func (e *Evaluator) funcConvert(v *ast.FuncConvertExpr) bool { + return true +} + +func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool { + return true +} + +func (e *Evaluator) funcSubstring(v *ast.FuncSubstringExpr) bool { + return true +} + +func (e *Evaluator) funcLocate(v *ast.FuncLocateExpr) bool { + return true +} + +func (e *Evaluator) funcTrim(v *ast.FuncTrimExpr) bool { + return true +} + +func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { + idx := e.aggregateMap[v] + aggr := e.aggregators[idx] + if e.aggregateDone { + v.SetValue(aggr.Output()) + return true + } + // TODO: currently only single argument aggregate functions are supported. + e.err = aggr.Input(v.Args[0].GetValue()) + return e.err == nil +} diff --git a/optimizer/evaluator_binop.go b/optimizer/evaluator_binop.go new file mode 100644 index 0000000000..0cda326144 --- /dev/null +++ b/optimizer/evaluator_binop.go @@ -0,0 +1,527 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "math" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" +) + +const ( + zeroI64 int64 = 0 + oneI64 int64 = 1 +) + +func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool { + // all operands must have same column. + if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { + return false + } + + // row constructor only supports comparison operation. + switch o.Op { + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + default: + if !checkAllOneColumn(o.L) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + } + + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + + switch o.Op { + case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: + return e.handleLogicOperation(o) + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + return e.handleComparisonOp(o) + case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: + // TODO: MySQL doesn't support and not, we should remove it later. + return e.handleBitOp(o) + case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: + return e.handleArithmeticOp(o) + default: + panic("should never happen") + } +} + +func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + var boolVal bool + switch o.Op { + case opcode.AndAnd: + boolVal = leftVal != zeroI64 && rightVal != zeroI64 + case opcode.OrOr: + boolVal = leftVal != zeroI64 || rightVal != zeroI64 + case opcode.LogicXor: + boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) + default: + panic("should never happen") + } + if boolVal { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + if types.IsNil(a) || types.IsNil(b) { + // for <=>, if a and b are both nil, return true. + // if a or b is nil, return false. + if o.Op == opcode.NullEQ { + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + } else { + o.SetValue(nil) + } + return true + } + + n, err := types.Compare(a, b) + if err != nil { + e.err = err + return false + } + + r, err := getCompResult(o.Op, n) + if err != nil { + e.err = err + return false + } + if r { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func getCompResult(op opcode.Op, value int) (bool, error) { + switch op { + case opcode.LT: + return value < 0, nil + case opcode.LE: + return value <= 0, nil + case opcode.GE: + return value >= 0, nil + case opcode.GT: + return value > 0, nil + case opcode.EQ: + return value == 0, nil + case opcode.NE: + return value != 0, nil + case opcode.NullEQ: + return value == 0, nil + default: + return false, errors.Errorf("invalid op %v in comparision operation", op) + } +} + +func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(nil) + return true + } + + x, err := types.ToInt64(a) + if err != nil { + e.err = err + return false + } + + y, err := types.ToInt64(b) + if err != nil { + e.err = err + return false + } + + // use a int64 for bit operator, return uint64 + switch o.Op { + case opcode.And: + o.SetValue(uint64(x & y)) + case opcode.Or: + o.SetValue(uint64(x | y)) + case opcode.Xor: + o.SetValue(uint64(x ^ y)) + case opcode.RightShift: + o.SetValue(uint64(x) >> uint64(y)) + case opcode.LeftShift: + o.SetValue(uint64(x) << uint64(y)) + default: + e.err = errors.Errorf("invalid op %v in bit operation", o.Op) + return false + } + return true +} + +func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { + a, err := coerceArithmetic(o.L.GetValue()) + if err != nil { + e.err = err + return false + } + + b, err := coerceArithmetic(o.R.GetValue()) + if err != nil { + e.err = err + return false + } + a, b = types.Coerce(a, b) + + if a == nil || b == nil { + // TODO: for <=>, if a and b are both nil, return true + o.SetValue(nil) + return true + } + + // TODO: support logic division DIV + var result interface{} + switch o.Op { + case opcode.Plus: + result, e.err = computePlus(a, b) + case opcode.Minus: + result, e.err = computeMinus(a, b) + case opcode.Mul: + result, e.err = computeMul(a, b) + case opcode.Div: + result, e.err = computeDiv(a, b) + case opcode.Mod: + result, e.err = computeMod(a, b) + case opcode.IntDiv: + result, e.err = computeIntDiv(a, b) + default: + e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) + return false + } + o.SetValue(result) + return e.err == nil +} + +func computePlus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.AddInt64(x, y) + case uint64: + return types.AddInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.AddInteger(x, y) + case uint64: + return types.AddUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x + y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Add(y), nil + } + } + return types.InvOp2(a, b, opcode.Plus) +} + +func computeMinus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.SubInt64(x, y) + case uint64: + return types.SubIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + return types.SubUintWithInt(x, y) + case uint64: + return types.SubUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x - y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Sub(y), nil + } + } + + return types.InvOp2(a, b, opcode.Minus) +} + +func computeMul(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.MulInt64(x, y) + case uint64: + return types.MulInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.MulInteger(x, y) + case uint64: + return types.MulUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x * y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Mul(y), nil + } + } + + return types.InvOp2(a, b, opcode.Mul) +} + +func computeDiv(a, b interface{}) (interface{}, error) { + // MySQL support integer divison Div and division operator / + // we use opcode.Div for division operator and will use another for integer division later. + // for division operator, we will use float64 for calculation. + switch x := a.(type) { + case float64: + y, err := types.ToFloat64(b) + if err != nil { + return nil, err + } + + if y == 0 { + return nil, nil + } + + return x / y, nil + default: + // the scale of the result is the scale of the first operand plus + // the value of the div_precision_increment system variable (which is 4 by default) + // we will use 4 here + + xa, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + xb, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + if f, _ := xb.Float64(); f == 0 { + // division by zero return null + return nil, nil + } + + return xa.Div(xb), nil + } +} + +func computeMod(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return x % y, nil + case uint64: + if y == 0 { + return nil, nil + } else if x < 0 { + // first is int64, return int64. + return -int64(uint64(-x) % y), nil + } + return int64(uint64(x) % y), nil + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } else if y < 0 { + // first is uint64, return uint64. + return uint64(x % uint64(-y)), nil + } + return x % uint64(y), nil + case uint64: + if y == 0 { + return nil, nil + } + return x % y, nil + } + case float64: + switch y := b.(type) { + case float64: + if y == 0 { + return nil, nil + } + return math.Mod(x, y), nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + xf, _ := x.Float64() + yf, _ := y.Float64() + if yf == 0 { + return nil, nil + } + return math.Mod(xf, yf), nil + } + } + + return types.InvOp2(a, b, opcode.Mod) +} + +func computeIntDiv(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivInt64(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return types.DivIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivUintWithInt(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return x / y, nil + } + } + + // if any is none integer, use decimal to calculate + x, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + y, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + + if f, _ := y.Float64(); f == 0 { + return nil, nil + } + + return x.Div(y).IntPart(), nil +} + +func coerceArithmetic(a interface{}) (interface{}, error) { + switch x := a.(type) { + case string: + // MySQL will convert string to float for arithmetic operation + f, err := types.StrToFloat(x) + if err != nil { + return nil, err + } + return f, err + case mysql.Time: + // if time has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case mysql.Duration: + // if duration has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case []byte: + // []byte is the same as string, converted to float for arithmetic operator. + f, err := types.StrToFloat(string(x)) + if err != nil { + return nil, err + } + return f, err + case mysql.Hex: + return x.ToNumber(), nil + case mysql.Bit: + return x.ToNumber(), nil + case mysql.Enum: + return x.ToNumber(), nil + case mysql.Set: + return x.ToNumber(), nil + default: + return x, nil + } +} diff --git a/optimizer/infobinder.go b/optimizer/infobinder.go new file mode 100644 index 0000000000..0ad59bceb6 --- /dev/null +++ b/optimizer/infobinder.go @@ -0,0 +1,441 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/model" +) + +// InfoBinder binds schema information for table name and column name and set result fields +// for ResetSetNode. +// We need to know which table a table name refers to, which column a column name refers to. +// +// In general, a reference can only refer to information that are available for it. +// So children elements are visited in the order that previous elements make information +// available for following elements. +// +// During visiting, information are collected and stored in binderContext. +// When we enter a sub query, a new binderContext is pushed to the contextStack, so sub query +// information can overwrite outer query information. When we look up for a column reference, +// we look up from top to bottom in the contextStack. +type InfoBinder struct { + Info infoschema.InfoSchema + DefaultSchema model.CIStr + Err error + + contextStack []*binderContext +} + +// binderContext stores information that table name and column name +// can be bind to. +type binderContext struct { + /* For Select Statement. */ + // table map to lookup and check table name conflict. + tableMap map[string]int + // tableSources collected in from clause. + tables []*ast.TableSource + // result fields collected in select field list. + fieldList []*ast.ResultField + // result fields collected in group by clause. + groupBy []*ast.ResultField + + // The join node stack is used by on condition to find out + // available tables to reference. On condition can only + // refer to tables involved in current join. + joinNodeStack []*ast.Join + + // When visiting TableRefs, tables in this context are not available + // because it is being collected. + inTableRefs bool + // When visiting on conditon only tables in current join node are available. + inOnCondition bool + // When visiting field list, fieldList in this context are not available. + inFieldList bool + // When visiting group by, groupBy fields are not available. + inGroupBy bool + // When visiting having, only fieldList and groupBy fields are available. + inHaving bool +} + +// currentContext gets the current binder context. +func (sb *InfoBinder) currentContext() *binderContext { + stackLen := len(sb.contextStack) + if stackLen == 0 { + return nil + } + return sb.contextStack[stackLen-1] +} + +// pushContext is called when we enter a statement. +func (sb *InfoBinder) pushContext() { + sb.contextStack = append(sb.contextStack, &binderContext{ + tableMap: map[string]int{}, + }) +} + +// popContext is called when we leave a statement. +func (sb *InfoBinder) popContext() { + sb.contextStack = sb.contextStack[:len(sb.contextStack)-1] +} + +// pushJoin is called when we enter a join node. +func (sb *InfoBinder) pushJoin(j *ast.Join) { + ctx := sb.currentContext() + ctx.joinNodeStack = append(ctx.joinNodeStack, j) +} + +// popJoin is called when we leave a join node. +func (sb *InfoBinder) popJoin() { + ctx := sb.currentContext() + ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1] +} + +// Enter implements ast.Visitor interface. +func (sb *InfoBinder) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { + switch v := inNode.(type) { + case *ast.SelectStmt: + sb.pushContext() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = true + case *ast.Join: + sb.pushJoin(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = true + case *ast.FieldList: + sb.currentContext().inFieldList = true + case *ast.GroupByClause: + sb.currentContext().inGroupBy = true + case *ast.HavingClause: + sb.currentContext().inHaving = true + case *ast.InsertStmt: + sb.pushContext() + case *ast.DeleteStmt: + sb.pushContext() + case *ast.UpdateStmt: + sb.pushContext() + } + return inNode, false +} + +// Leave implements ast.Visitor interface. +func (sb *InfoBinder) Leave(inNode ast.Node) (node ast.Node, ok bool) { + switch v := inNode.(type) { + case *ast.TableName: + sb.handleTableName(v) + case *ast.ColumnName: + sb.handleColumnName(v) + case *ast.TableSource: + sb.handleTableSource(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = false + case *ast.Join: + sb.handleJoin(v) + sb.popJoin() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = false + case *ast.FieldList: + sb.handleFieldList(v) + sb.currentContext().inFieldList = false + case *ast.GroupByClause: + sb.currentContext().inGroupBy = false + case *ast.HavingClause: + sb.currentContext().inHaving = false + case *ast.SelectStmt: + v.SetResultFields(sb.currentContext().fieldList) + sb.popContext() + case *ast.InsertStmt: + sb.popContext() + case *ast.DeleteStmt: + sb.popContext() + case *ast.UpdateStmt: + sb.popContext() + } + return inNode, sb.Err == nil +} + +// handleTableName looks up and bind the schema information for table name +// and set result fields for table name. +func (sb *InfoBinder) handleTableName(tn *ast.TableName) { + if tn.Schema.L == "" { + tn.Schema = sb.DefaultSchema + } + table, err := sb.Info.TableByName(tn.Schema, tn.Name) + if err != nil { + sb.Err = err + return + } + tn.TableInfo = table.Meta() + dbInfo, _ := sb.Info.SchemaByName(tn.Schema) + tn.DBInfo = dbInfo + + rfs := make([]*ast.ResultField, len(tn.TableInfo.Columns)) + for i, v := range tn.TableInfo.Columns { + rfs[i] = &ast.ResultField{ + Column: v, + Table: tn.TableInfo, + DBName: tn.Schema, + } + } + tn.SetResultFields(rfs) + return +} + +// handleTableSources checks name duplication +// and puts the table source in current binderContext. +func (sb *InfoBinder) handleTableSource(ts *ast.TableSource) { + for _, v := range ts.GetResultFields() { + v.TableAsName = ts.AsName + } + var name string + if ts.AsName.L != "" { + name = ts.AsName.L + } else { + tableName := ts.Source.(*ast.TableName) + name = sb.tableUniqueName(tableName.Schema, tableName.Name) + } + ctx := sb.currentContext() + if _, ok := ctx.tableMap[name]; ok { + sb.Err = errors.Errorf("duplicated table/alias name %s", name) + return + } + ctx.tableMap[name] = len(ctx.tables) + ctx.tables = append(ctx.tables, ts) + return +} + +// handleJoin sets result fields for join. +func (sb *InfoBinder) handleJoin(j *ast.Join) { + if j.Right == nil { + j.SetResultFields(j.Left.GetResultFields()) + return + } + leftLen := len(j.Left.GetResultFields()) + rightLen := len(j.Right.GetResultFields()) + rfs := make([]*ast.ResultField, leftLen+rightLen) + copy(rfs, j.Left.GetResultFields()) + copy(rfs[leftLen:], j.Right.GetResultFields()) + j.SetResultFields(rfs) +} + +// handleColumnName looks up and binds schema information to +// the column name. +func (sb *InfoBinder) handleColumnName(cn *ast.ColumnName) { + ctx := sb.currentContext() + if ctx.inOnCondition { + // In on condition, only tables within current join is available. + sb.bindColumnNameInOnCondition(cn) + return + } + + // Try to bind the column name form top to bottom in the context stack. + for i := len(sb.contextStack) - 1; i >= 0; i-- { + if sb.bindColumnNameInContext(sb.contextStack[i], cn) { + // Column is already bound or encountered an error. + return + } + } + sb.Err = errors.Errorf("Unknown column %s", cn.Name.L) +} + +// bindColumnNameInContext looks up and binds schema information for a column with the ctx. +func (sb *InfoBinder) bindColumnNameInContext(ctx *binderContext, cn *ast.ColumnName) (done bool) { + if cn.Table.L == "" { + // If qualified table name is not specified in column name, the column name may be ambiguous, + // We need to iterate over all tables and + } + + if ctx.inTableRefs { + // In TableRefsClause, column reference only in join on condition which is handled before. + return false + } + if ctx.inFieldList { + // only bind column using tables. + return sb.bindColumnInTableSources(cn, ctx.tables) + } + if ctx.inGroupBy { + // field list first, then tables. + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + return sb.bindColumnInTableSources(cn, ctx.tables) + } + // column name in other places can be looked up in the same order. + if sb.bindColumnInResultFields(cn, ctx.groupBy) { + return true + } + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + + // tables is not available for having clause. + if !ctx.inHaving { + return sb.bindColumnInTableSources(cn, ctx.tables) + } + return false +} + +// bindColumnNameInOnCondition looks up for column name in current join, and +// binds the schema information. +func (sb *InfoBinder) bindColumnNameInOnCondition(cn *ast.ColumnName) { + ctx := sb.currentContext() + join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1] + tableSources := appendTableSources(nil, join) + if !sb.bindColumnInTableSources(cn, tableSources) { + sb.Err = errors.Errorf("unkown column name %s", cn.Name.O) + } +} + +func (sb *InfoBinder) bindColumnInTableSources(cn *ast.ColumnName, tableSources []*ast.TableSource) (done bool) { + var matchedResultField *ast.ResultField + if cn.Table.L != "" { + var matchedTable ast.ResultSetNode + for _, ts := range tableSources { + if cn.Table.L == ts.AsName.L { + // different table name. + matchedTable = ts + break + } + if tn, ok := ts.Source.(*ast.TableName); ok { + if cn.Table.L == tn.Name.L { + matchedTable = ts + } + } + } + if matchedTable != nil { + resultFields := matchedTable.GetResultFields() + for _, rf := range resultFields { + if rf.ColumnAsName.L == cn.Name.L || rf.Column.Name.L == cn.Name.L { + // bind column. + matchedResultField = rf + break + } + } + } + } else { + for _, ts := range tableSources { + rfs := ts.GetResultFields() + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return true + } + matchedResultField = rf + } + } + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +func (sb *InfoBinder) bindColumnInResultFields(cn *ast.ColumnName, rfs []*ast.ResultField) bool { + var matchedResultField *ast.ResultField + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return false + } + matchedResultField = rf + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +// handleFieldList expands wild card field and set fieldList in current context. +func (sb *InfoBinder) handleFieldList(fieldList *ast.FieldList) { + var resultFields []*ast.ResultField + for _, v := range fieldList.Fields { + resultFields = append(resultFields, sb.createResultFields(v)...) + } + sb.currentContext().fieldList = resultFields +} + +// createResultFields creates result field list for a single select field. +func (sb *InfoBinder) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) { + ctx := sb.currentContext() + if field.WildCard != nil { + if len(ctx.tables) == 0 { + sb.Err = errors.Errorf("No table used.") + return + } + if field.WildCard.Table.L == "" { + for _, v := range ctx.tables { + rfs = append(rfs, v.GetResultFields()...) + } + } else { + name := sb.tableUniqueName(field.WildCard.Schema, field.WildCard.Table) + tableIdx, ok := ctx.tableMap[name] + if !ok { + sb.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O) + } + rfs = ctx.tables[tableIdx].GetResultFields() + } + return + } + // The column is visited before so it must has been bound already. + rf := &ast.ResultField{ColumnAsName: field.AsName} + switch v := field.Expr.(type) { + case *ast.ColumnNameExpr: + rf.Column = v.Name.ColumnInfo + rf.Table = v.Name.TableInfo + rf.DBName = v.Name.Schema + default: + if field.AsName.L == "" { + rf.ColumnAsName.L = field.Expr.Text() + rf.ColumnAsName.O = rf.ColumnAsName.L + } + } + rfs = append(rfs, rf) + return +} + +func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) { + switch v := resultSetNode.(type) { + case *ast.TableSource: + out = append(in, v) + case *ast.Join: + out = appendTableSources(in, v.Left) + if v.Right != nil { + out = appendTableSources(out, v.Right) + } + } + return +} + +func (sb *InfoBinder) tableUniqueName(schema, table model.CIStr) string { + if schema.L != "" && schema.L != sb.DefaultSchema.L { + return schema.L + "." + table.L + } + return table.L +} diff --git a/optimizer/infobinder_test.go b/optimizer/infobinder_test.go new file mode 100644 index 0000000000..b14eadb87d --- /dev/null +++ b/optimizer/infobinder_test.go @@ -0,0 +1,82 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer_test + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/optimizer" + "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/testkit" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testInfoBinderSuite{}) + +type testInfoBinderSuite struct { +} + +type binderVerifier struct { + c *C +} + +func (bv *binderVerifier) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (bv *binderVerifier) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ColumnName: + bv.c.Assert(v.ColumnInfo, NotNil) + case *ast.TableName: + bv.c.Assert(v.TableInfo, NotNil) + } + return in, true +} + +func (ts *testInfoBinderSuite) TestInfoBinder(c *C) { + store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory) + c.Assert(err, IsNil) + defer store.Close() + testKit := testkit.NewTestKit(c, store) + testKit.MustExec("use test") + testKit.MustExec("create table t (c1 int, c2 int)") + domain := sessionctx.GetDomain(testKit.Se.(context.Context)) + + src := "SELECT c1 from t" + l := parser.NewLexer(src) + c.Assert(parser.YYParse(l), Equals, 0) + stmts := l.Stmts() + c.Assert(len(stmts), Equals, 1) + v := &optimizer.InfoBinder{ + Info: domain.InfoSchema(), + DefaultSchema: model.NewCIStr("test"), + } + selectStmt := stmts[0].(*ast.SelectStmt) + selectStmt.Accept(v) + + verifier := &binderVerifier{ + c: c, + } + selectStmt.Accept(verifier) +} diff --git a/optimizer/optimizer.go b/optimizer/typecomputer.go similarity index 60% rename from optimizer/optimizer.go rename to optimizer/typecomputer.go index 36ccd06c5f..f2cd35aff0 100644 --- a/optimizer/optimizer.go +++ b/optimizer/typecomputer.go @@ -13,20 +13,18 @@ package optimizer -import ( - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/stmt" -) +import "github.com/pingcap/tidb/ast" -// Compile compiles a ast.Node into a executable statement. -func Compile(node ast.Node) (stmt.Statement, error) { - switch v := node.(type) { - case *ast.SetStmt: - return compileSet(v) - } - return nil, nil +// typeComputer is an ast Visitor that +// computes result type for ast.ExprNode. +type typeComputer struct { + err error } -func compileSet(aset *ast.SetStmt) (stmt.Statement, error) { - return nil, nil +func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true } diff --git a/optimizer/validator.go b/optimizer/validator.go new file mode 100644 index 0000000000..f90a757ab8 --- /dev/null +++ b/optimizer/validator.go @@ -0,0 +1,30 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import "github.com/pingcap/tidb/ast" + +// validator is an ast.Visitor that validates +// ast parsed from parser. +type validator struct { + err error +} + +func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true +} diff --git a/parser/parser.y b/parser/parser.y index d918b17ee1..1393a7dff1 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -26,22 +26,13 @@ package parser import ( - "fmt" "strings" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/parser/coldef" - "github.com/pingcap/tidb/ddl" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/subquery" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/parser/opcode" - "github.com/pingcap/tidb/rset/rsets" - "github.com/pingcap/tidb/stmt" - "github.com/pingcap/tidb/stmt/stmts" - "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) @@ -347,8 +338,8 @@ import ( %type AlterTableStmt "Alter table statement" - AlterSpecification "Alter table specification" - AlterSpecificationList "Alter table specification list" + AlterTableSpec "Alter table specification" + AlterTableSpecList "Alter table specification list" AnyOrAll "Any or All for subquery" Assignment "assignment" AssignmentList "assignment list" @@ -357,10 +348,7 @@ import ( AuthString "Password string value" BeginTransactionStmt "BEGIN TRANSACTION statement" CastType "Cast function target type" - CharsetName "Charset Name" - CollationName "Collation Name" ColumnDef "table column definition" - ColumnListOpt "Optional column list" ColumnName "column name" ColumnNameList "column name list" ColumnNameListOpt "column name list opt" @@ -370,17 +358,18 @@ import ( CommaOpt "optional comma" CommitStmt "COMMIT statement" CompareOp "Compare opcode" - Constraint "column value constraint" - ConstraintElem "table define constraint element" + ColumnOption "column definition option" + ColumnOptionList "column definition option list" + ColumnOptionListOpt "optional column definition option list" + Constraint "table constraint" + ConstraintElem "table constraint element" ConstraintKeywordOpt "Constraint Keyword or empty" - ConstraintOpt "optional column value constraint" - ConstraintOpts "optional column value constraints" CreateDatabaseStmt "Create Database Statement" CreateIndexStmt "CREATE INDEX statement" CreateIndexStmtUnique "CREATE INDEX optional UNIQUE clause" - CreateSpecification "CREATE Database specification" - CreateSpecificationList "CREATE Database specification list" - CreateSpecListOpt "CREATE Database specification list opt" + DatabaseOption "CREATE Database specification" + DatabaseOptionList "CREATE Database specification list" + DatabaseOptionListOpt "CREATE Database specification list opt" CreateTableStmt "CREATE TABLE statement" CreateUserStmt "CREATE User statement" CrossOpt "Cross join option" @@ -414,7 +403,7 @@ import ( FieldAsName "Field alias name" FieldAsNameOpt "Field alias name opt" FieldList "field expression list" - FromClause "From clause" + TableRefsClause "Table references clause" Function "function expr" FunctionCallAgg "Function call on aggregate data" FunctionCallConflict "Function call with reserved keyword as function name" @@ -425,7 +414,6 @@ import ( GlobalScope "The scope of variable" GrantStmt "Grant statement" GroupByClause "GROUP BY clause" - GroupByList "GROUP BY list" HashString "Hashed string" HavingClause "HAVING clause" IfExists "If Exists" @@ -459,11 +447,10 @@ import ( OptInteger "Optional Integer keyword" Order "ORDER BY clause optional collation specification" OrderBy "ORDER BY clause" - OrderByItem "ORDER BY list item" + ByItem "BY item" OrderByOptional "Optional ORDER BY clause optional" - OrderByList "ORDER BY list" + ByList "BY list" OuterOpt "optional OUTER clause" - QualifiedIdent "qualified identifier" QuickOptional "QUICK or empty" PasswordOpt "Password option" ColumnPosition "Column position [First|After ColumnName]" @@ -489,45 +476,43 @@ import ( SelectStmtLimit "SELECT statement optional LIMIT clause" SelectStmtOpts "Select statement options" SelectStmtGroup "SELECT statement optional GROUP BY clause" - SelectStmtOrder "SELECT statement optional ORDER BY clause" SetStmt "Set variable statement" ShowStmt "Show engines/databases/tables/columns/warnings statement" ShowDatabaseNameOpt "Show tables/columns statement database name option" + ShowTableAliasOpt "Show table alias option" ShowLikeOrWhereOpt "Show like or where condition option" - ShowTableIdentOpt "Show columns statement table name option" SignedLiteral "Literal or NumLiteral with sign" - SimpleQualifiedIdent "Qualified identifier without *" Statement "statement" StatementList "statement list" + StringName "string literal or identifier" StringList "string list" ExplainableStmt "explainable statement" SubSelect "Sub Select" Symbol "Constraint Symbol" SystemVariable "System defined variable name" - TableAsOpt "table as option" - TableConstraint "table constraint definition" + TableAsName "table alias name" + TableAsNameOpt "table alias name optional" TableElement "table definition element" TableElementList "table definition element list" - TableElementListOpt "table definition element list opt" TableFactor "table factor" - TableIdent "Table identifier" - TableIdentList "Table identifier list" - TableIdentOpt "Table identifier option" - TableOpt "create table option" - TableOptList "create table option list" - TableOptListOpt "create table option list opt" + TableName "Table name" + TableNameList "Table name list" + TableOption "create table option" + TableOptionList "create table option list" + TableOptionListOpt "create table option list opt" TableRef "table reference" TableRefs "table references" TimeUnit "Time unit" TrimDirection "Trim string direction" TruncateTableStmt "TRANSACTION TABLE statement" UnionOpt "Union Option(empty/ALL/DISTINCT)" - UnionSelect "Union select/(select)" - UnionStmt "Union statement" + UnionStmt "Union select state ment" + UnionClauseList "Union select clause list" + UnionSelect "Union (select) item" UpdateStmt "UPDATE statement" Username "Username" - UserSpecification "Username and auth option" - UserSpecificationList "Username and auth option list" + UserSpec "Username and auth option" + UserSpecList "Username and auth option list" UserVariable "User defined variable name" UserVariableList "User defined variable name list" UseStmt "USE statement" @@ -574,8 +559,6 @@ import ( NUM "numbers" LengthNum "Field length num(uint64)" - - %token tableRefPriority %precedence lowerThanCalcFoundRows @@ -624,70 +607,67 @@ Start: StatementList | parseExpression Expression { - yylex.(*lexer).expr = expression.Expr($2) + yylex.(*lexer).expr = $2.(ast.ExprNode) } /**************************************AlterTableStmt*************************************** * See: https://dev.mysql.com/doc/refman/5.7/en/alter-table.html *******************************************************************************************/ AlterTableStmt: - "ALTER" IgnoreOptional "TABLE" TableIdent AlterSpecificationList + "ALTER" IgnoreOptional "TABLE" TableName AlterTableSpecList { - $$ = &stmts.AlterTableStmt{ - Ident: $4.(table.Ident), - Specs: $5.([]*ddl.AlterSpecification), + $$ = &ast.AlterTableStmt{ + Table: $4.(*ast.TableName), + Specs: $5.([]*ast.AlterTableSpec), } } -AlterSpecification: - TableOptListOpt +AlterTableSpec: + TableOptionListOpt { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterTableOpt, - TableOpts: $1.([]*coldef.TableOpt), + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableOption, + Options:$1.([]*ast.TableOption), } } | "ADD" ColumnKeywordOpt ColumnDef ColumnPosition { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterAddColumn, - Column: $3.(*coldef.ColumnDef), - Position: $4.(*ddl.ColumnPosition), + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableAddColumn, + Column: $3.(*ast.ColumnDef), + Position: $4.(*ast.ColumnPosition), } } -| "ADD" ConstraintKeywordOpt ConstraintElem +| "ADD" Constraint { - constraint := $3.(*coldef.TableConstraint) - if $2 != nil { - constraint.ConstrName = $2.(string) - } - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterAddConstr, + constraint := $2.(*ast.Constraint) + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableAddConstraint, Constraint: constraint, } } | "DROP" ColumnKeywordOpt ColumnName { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterDropColumn, - Name: $3.(string), + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableDropColumn, + DropColumn: $3.(*ast.ColumnName), } } | "DROP" "PRIMARY" "KEY" { - $$ = &ddl.AlterSpecification{Action: ddl.AlterDropPrimaryKey} + $$ = &ast.AlterTableSpec{Tp: ast.AlterTableDropPrimaryKey} } | "DROP" KeyOrIndex IndexName { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterDropIndex, + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableDropIndex, Name: $3.(string), } } | "DROP" "FOREIGN" "KEY" Symbol { - $$ = &ddl.AlterSpecification{ - Action: ddl.AlterDropForeignKey, + $$ = &ast.AlterTableSpec{ + Tp: ast.AlterTableDropForeignKey, Name: $4.(string), } } @@ -701,28 +681,28 @@ ColumnKeywordOpt: ColumnPosition: { - $$ = &ddl.ColumnPosition{Type: ddl.ColumnPositionNone} + $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionNone} } | "FIRST" { - $$ = &ddl.ColumnPosition{Type: ddl.ColumnPositionFirst} + $$ = &ast.ColumnPosition{Tp: ast.ColumnPositionFirst} } | "AFTER" ColumnName { - $$ = &ddl.ColumnPosition{ - Type: ddl.ColumnPositionAfter, - RelativeColumn: $2.(string), + $$ = &ast.ColumnPosition{ + Tp: ast.ColumnPositionAfter, + RelativeColumn: $2.(*ast.ColumnName), } } -AlterSpecificationList: - AlterSpecification +AlterTableSpecList: + AlterTableSpec { - $$ = []*ddl.AlterSpecification{$1.(*ddl.AlterSpecification)} + $$ = []*ast.AlterTableSpec{$1.(*ast.AlterTableSpec)} } -| AlterSpecificationList ',' AlterSpecification +| AlterTableSpecList ',' AlterTableSpec { - $$ = append($1.([]*ddl.AlterSpecification), $3.(*ddl.AlterSpecification)) + $$ = append($1.([]*ast.AlterTableSpec), $3.(*ast.AlterTableSpec)) } ConstraintKeywordOpt: @@ -743,183 +723,210 @@ Symbol: /*******************************************************************************************/ Assignment: - SimpleQualifiedIdent eq Expression + ColumnName eq Expression { - x, err := expression.NewAssignment($1.(string), $3.(expression.Expression)) - if err != nil { - yylex.(*lexer).errf("Parse Assignment error: %s", $1.(string)) - return 1 - } - $$ = x + $$ = &ast.Assignment{Column: $1.(*ast.ColumnName), Expr:$3.(ast.ExprNode)} } AssignmentList: Assignment { - $$ = []*expression.Assignment{$1.(*expression.Assignment)} + $$ = []*ast.Assignment{$1.(*ast.Assignment)} } | AssignmentList ',' Assignment { - $$ = append($1.([]*expression.Assignment), $3.(*expression.Assignment)) + $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) } AssignmentListOpt: /* EMPTY */ { - $$ = []*expression.Assignment{} + $$ = []*ast.Assignment{} } | AssignmentList BeginTransactionStmt: "BEGIN" { - $$ = &stmts.BeginStmt{} + $$ = &ast.BeginStmt{} } | "START" "TRANSACTION" { - $$ = &stmts.BeginStmt{} + $$ = &ast.BeginStmt{} } ColumnDef: - ColumnName Type ConstraintOpts + ColumnName Type ColumnOptionListOpt { - $$ = &coldef.ColumnDef{Name: $1.(string), Tp: $2.(*types.FieldType), Constraints: $3.([]*coldef.ConstraintOpt)} + $$ = &ast.ColumnDef{Name: $1.(*ast.ColumnName), Tp: $2.(*types.FieldType), Options: $3.([]*ast.ColumnOption)} } ColumnName: Identifier + { + $$ = &ast.ColumnName{Name: model.NewCIStr($1.(string))} + } +| Identifier '.' Identifier + { + $$ = &ast.ColumnName{Table: model.NewCIStr($1.(string)), Name: model.NewCIStr($3.(string))} + } +| Identifier '.' Identifier '.' Identifier + { + $$ = &ast.ColumnName{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string)), Name: model.NewCIStr($5.(string))} + } ColumnNameList: ColumnName { - $$ = []string{$1.(string)} + $$ = []*ast.ColumnName{$1.(*ast.ColumnName)} } | ColumnNameList ',' ColumnName { - $$ = append($1.([]string), $3.(string)) + $$ = append($1.([]*ast.ColumnName), $3.(*ast.ColumnName)) } ColumnNameListOpt: /* EMPTY */ { - $$ = []string{} + $$ = []*ast.ColumnName{} } | ColumnNameList + { + $$ = $1.([]*ast.ColumnName) + } CommitStmt: "COMMIT" { - $$ = &stmts.CommitStmt{} + $$ = &ast.CommitStmt{} } -Constraint: +ColumnOption: "NOT" "NULL" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrNotNull, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNotNull} } | "NULL" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrNull, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionNull} } | "AUTO_INCREMENT" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrAutoIncrement, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionAutoIncrement} } | "PRIMARY" "KEY" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrPrimaryKey, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionPrimaryKey} } | "UNIQUE" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrUniq, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniq} } | "UNIQUE" "KEY" { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrUniqKey, Bvalue: true} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionUniqKey} } | "DEFAULT" DefaultValueExpr { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrDefaultValue, Evalue: $2.(expression.Expression)} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionDefaultValue, Expr: $2.(ast.ExprNode)} } | "ON" "UPDATE" NowSym { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrOnUpdate, Evalue: $3.(expression.Expression)} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionOnUpdate, Expr: $3.(ast.ExprNode)} } | "COMMENT" stringLit { - $$ = &coldef.ConstraintOpt{Tp: coldef.ConstrComment} + $$ = &ast.ColumnOption{Tp: ast.ColumnOptionComment} } | "CHECK" '(' Expression ')' { // See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html // The CHECK clause is parsed but ignored by all storage engines. - $$ = nil + $$ = &ast.ColumnOption{} + } + +ColumnOptionList: + ColumnOption + { + $$ = []*ast.ColumnOption{$1.(*ast.ColumnOption)} + } +| ColumnOptionList ColumnOption + { + $$ = append($1.([]*ast.ColumnOption), $2.(*ast.ColumnOption)) + } + +ColumnOptionListOpt: + { + $$ = []*ast.ColumnOption{} + } +| ColumnOptionList + { + $$ = $1.([]*ast.ColumnOption) } ConstraintElem: "PRIMARY" "KEY" '(' IndexColNameList ')' { - ce := &coldef.TableConstraint{} - ce.Tp = coldef.ConstrPrimaryKey - ce.Keys = $4.([]*coldef.IndexColName) - $$ = ce + $$ = &ast.Constraint{Tp: ast.ConstraintPrimaryKey, Keys: $4.([]*ast.IndexColName)} } | "FULLTEXT" "KEY" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrFulltext, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintFulltext, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string), + } } | "INDEX" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrIndex, - Keys: $4.([]*coldef.IndexColName), - ConstrName: $2.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintIndex, + Keys: $4.([]*ast.IndexColName), + Name: $2.(string), + } } | "KEY" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrKey, - Keys: $4.([]*coldef.IndexColName), - ConstrName: $2.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintKey, + Keys: $4.([]*ast.IndexColName), + Name: $2.(string)} } | "UNIQUE" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrUniq, - Keys: $4.([]*coldef.IndexColName), - ConstrName: $2.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintUniq, + Keys: $4.([]*ast.IndexColName), + Name: $2.(string)} } | "UNIQUE" "INDEX" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrUniqIndex, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintUniqIndex, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string)} } | "UNIQUE" "KEY" IndexName '(' IndexColNameList ')' { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrUniqKey, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string)} + $$ = &ast.Constraint{ + Tp: ast.ConstraintUniqKey, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string)} } | "FOREIGN" "KEY" IndexName '(' IndexColNameList ')' ReferDef { - $$ = &coldef.TableConstraint{ - Tp: coldef.ConstrForeignKey, - Keys: $5.([]*coldef.IndexColName), - ConstrName: $3.(string), - Refer: $7.(*coldef.ReferenceDef), - } + $$ = &ast.Constraint{ + Tp: ast.ConstraintForeignKey, + Keys: $5.([]*ast.IndexColName), + Name: $3.(string), + Refer: $7.(*ast.ReferenceDef), + } } ReferDef: - "REFERENCES" TableIdent '(' IndexColNameList ')' + "REFERENCES" TableName '(' IndexColNameList ')' { - $$ = &coldef.ReferenceDef{TableIdent: $2.(table.Ident), IndexColNames: $4.([]*coldef.IndexColName)} + $$ = &ast.ReferenceDef{Table: $2.(*ast.TableName), IndexColNames: $4.([]*ast.IndexColName)} } /* @@ -941,7 +948,7 @@ DefaultValueExpr: NowSym: "CURRENT_TIMESTAMP" { - $$ = &expression.Ident{CIStr: model.NewCIStr("CURRENT_TIMESTAMP")} + $$ = &ast.IdentifierExpr{Name: model.NewCIStr("CURRENT_TIMESTAMP")} } | "LOCALTIME" | "LOCALTIMESTAMP" @@ -950,62 +957,31 @@ NowSym: SignedLiteral: Literal { - $$ = expression.Value{Val: $1} + $$ = ast.NewValueExpr($1) } | '+' NumLiteral { - n := expression.Value{Val: $2} - $$ = expression.NewUnaryOperation(opcode.Plus, n) + $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: ast.NewValueExpr($2)} } | '-' NumLiteral { - n := expression.Value{Val: $2} - $$ = expression.NewUnaryOperation(opcode.Minus, n) + $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr($2)} } // TODO: support decimal literal NumLiteral: intLit | floatLit - -ConstraintOpt: - Constraint - -ConstraintOpts: - { - $$ = []*coldef.ConstraintOpt{} - } -| ConstraintOpts ConstraintOpt - { - if $2 != nil { - $$ = append($1.([]*coldef.ConstraintOpt), $2.(*coldef.ConstraintOpt)) - } else { - $$ = $1 - } - } - CreateIndexStmt: - "CREATE" CreateIndexStmtUnique "INDEX" Identifier "ON" TableIdent '(' IndexColNameList ')' + "CREATE" CreateIndexStmtUnique "INDEX" Identifier "ON" TableName '(' IndexColNameList ')' { - indexName, tableIdent, colNameList := $4.(string), $6.(table.Ident), $8.([]*coldef.IndexColName) - if strings.EqualFold(indexName, tableIdent.Name.O) { - yylex.(*lexer).errf("index name collision: %s", indexName) - return 1 - } - for _, colName := range colNameList { - if indexName == colName.ColumnName { - yylex.(*lexer).errf("index name collision: %s", indexName) - return 1 - } - } - - $$ = &stmts.CreateIndexStmt{ - Unique: $2.(bool), - IndexName: indexName, - TableIdent: tableIdent, - IndexColNames: colNameList, + $$ = &ast.CreateIndexStmt{ + Unique: $2.(bool), + IndexName: $4.(string), + Table: $6.(*ast.TableName), + IndexColNames: $8.([]*ast.IndexColName), } if yylex.(*lexer).root { break @@ -1025,20 +1001,20 @@ IndexColName: ColumnName OptFieldLen Order { //Order is parsed but just ignored as MySQL did - $$ = &coldef.IndexColName{ColumnName: $1.(string), Length: $2.(int)} + $$ = &ast.IndexColName{Column: $1.(*ast.ColumnName), Length: $2.(int)} } IndexColNameList: { - $$ = []*coldef.IndexColName{} + $$ = []*ast.IndexColName{} } | IndexColName { - $$ = []*coldef.IndexColName{$1.(*coldef.IndexColName)} + $$ = []*ast.IndexColName{$1.(*ast.IndexColName)} } | IndexColNameList ',' IndexColName { - $$ = append($1.([]*coldef.IndexColName), $3.(*coldef.IndexColName)) + $$ = append($1.([]*ast.IndexColName), $3.(*ast.IndexColName)) } @@ -1054,30 +1030,14 @@ IndexColNameList: * | [DEFAULT] COLLATE [=] collation_name *******************************************************************/ CreateDatabaseStmt: - "CREATE" DatabaseSym IfNotExists DBName CreateSpecListOpt + "CREATE" DatabaseSym IfNotExists DBName DatabaseOptionListOpt { - opts := $5.([]*coldef.DatabaseOpt) - //compose charset from x - var cs, co string - for _, x := range opts { - switch x.Tp { - case coldef.DBOptCharset, coldef.DBOptCollate: - cs = x.Value - } + $$ = &ast.CreateDatabaseStmt{ + IfNotExists: $3.(bool), + Name: $4.(string), + Options: $5.([]*ast.DatabaseOption), } - ok := charset.ValidCharsetAndCollation(cs, co) - if !ok { - yylex.(*lexer).errf("Unknown character set %s or collate %s ", cs, co) - return 1 - } - dbopt := &coldef.CharsetOpt{Chs: cs, Col: co} - - $$ = &stmts.CreateDatabaseStmt{ - IfNotExists: $3.(bool), - Name: $4.(string), - Opt: dbopt} - if yylex.(*lexer).root { break } @@ -1086,55 +1046,30 @@ CreateDatabaseStmt: DBName: Identifier -CreateSpecification: - DefaultKwdOpt CharsetKw EqOpt CharsetName +DatabaseOption: + DefaultKwdOpt CharsetKw EqOpt StringName { - $$ = &coldef.DatabaseOpt{Tp: coldef.DBOptCollate, Value: $4.(string)} + $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCharset, Value: $4.(string)} } -| DefaultKwdOpt "COLLATE" EqOpt CollationName +| DefaultKwdOpt "COLLATE" EqOpt StringName { - $$ = &coldef.DatabaseOpt{Tp: coldef.DBOptCollate, Value: $4.(string)} + $$ = &ast.DatabaseOption{Tp: ast.DatabaseOptionCollate, Value: $4.(string)} } -CharsetName: - Identifier +DatabaseOptionListOpt: { - c := strings.ToLower($1.(string)) - if charset.ValidCharsetAndCollation(c, "") { - $$ = c - } else { - yylex.(*lexer).errf("Unknown character set: '%s'", $1.(string)) - return 1 - } - } -| stringLit - { - c := strings.ToLower($1.(string)) - if charset.ValidCharsetAndCollation(c, "") { - $$ = c - } else { - yylex.(*lexer).errf("Unknown character set: '%s'", $1.(string)) - return 1 - } + $$ = []*ast.DatabaseOption{} } +| DatabaseOptionList -CollationName: - Identifier - -CreateSpecListOpt: +DatabaseOptionList: + DatabaseOption { - $$ = []*coldef.DatabaseOpt{} + $$ = []*ast.DatabaseOption{$1.(*ast.DatabaseOption)} } -| CreateSpecificationList - -CreateSpecificationList: - CreateSpecification +| DatabaseOptionList DatabaseOption { - $$ = []*coldef.DatabaseOpt{$1.(*coldef.DatabaseOpt)} - } -| CreateSpecificationList CreateSpecification - { - $$ = append($1.([]*coldef.DatabaseOpt), $2.(*coldef.DatabaseOpt)) + $$ = append($1.([]*ast.DatabaseOption), $2.(*ast.DatabaseOption)) } /******************************************************************* @@ -1153,45 +1088,30 @@ CreateSpecificationList: * ) *******************************************************************/ CreateTableStmt: - "CREATE" "TABLE" IfNotExists TableIdent '(' TableElementListOpt ')' TableOptListOpt + "CREATE" "TABLE" IfNotExists TableName '(' TableElementList ')' TableOptionListOpt { tes := $6.([]interface {}) - var columnDefs []*coldef.ColumnDef - var tableConstraints []*coldef.TableConstraint + var columnDefs []*ast.ColumnDef + var constraints []*ast.Constraint for _, te := range tes { switch te := te.(type) { - case *coldef.ColumnDef: + case *ast.ColumnDef: columnDefs = append(columnDefs, te) - case *coldef.TableConstraint: - tableConstraints = append(tableConstraints, te) + case *ast.Constraint: + constraints = append(constraints, te) } } if len(columnDefs) == 0 { yylex.(*lexer).err("Column Definition List can't be empty.") return 1 } - - opt := &coldef.TableOption{} - if $8 != nil { - for _, o := range $8.([]*coldef.TableOpt) { - switch o.Tp { - case coldef.TblOptEngine: - opt.Engine = o.StrValue - case coldef.TblOptCharset: - opt.Charset = o.StrValue - case coldef.TblOptCollate: - opt.Collate = o.StrValue - case coldef.TblOptAutoIncrement: - opt.AutoIncrement = o.UintValue - } - } - } - $$ = &stmts.CreateTableStmt{ - Ident: $4.(table.Ident), + $$ = &ast.CreateTableStmt{ + Table: $4.(*ast.TableName), IfNotExists: $3.(bool), Cols: columnDefs, - Constraints: tableConstraints, - Opt: opt} + Constraints: constraints, + Options: $8.([]*ast.TableOption), + } } Default: @@ -1217,8 +1137,8 @@ DefaultKwdOpt: DoStmt: "DO" ExpressionList { - $$ = &stmts.DoStmt { - Exprs: $2.([]expression.Expression), + $$ = &ast.DoStmt { + Exprs: $2.([]ast.ExprNode), } } @@ -1228,27 +1148,24 @@ DoStmt: * *******************************************************************/ DeleteFromStmt: - "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableIdent WhereClauseOptional OrderByOptional LimitClause + "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableName WhereClauseOptional OrderByOptional LimitClause { // Single Table - ts := &rsets.TableSource{Source: $6} - r := &rsets.JoinRset{Left: ts, Right: nil} - x := &stmts.DeleteStmt{ - Refs: r, + join := &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil} + x := &ast.DeleteStmt{ + TableRefs: &ast.TableRefsClause{TableRefs: join}, LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), } if $7 != nil { - x.Where = $7.(expression.Expression) + x.Where = $7.(ast.ExprNode) } - if $8 != nil { - x.Order = $8.(*rsets.OrderByRset) + x.Order = $8.(*ast.OrderByClause) } - if $9 != nil { - x.Limit = $9.(*rsets.LimitRset) + x.Limit = $9.(*ast.Limit) } $$ = x @@ -1256,39 +1173,39 @@ DeleteFromStmt: break } } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional TableIdentList "FROM" TableRefs WhereClauseOptional +| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional TableNameList "FROM" TableRefs WhereClauseOptional { // Multiple Table - x := &stmts.DeleteStmt{ + x := &ast.DeleteStmt{ LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), MultiTable: true, BeforeFrom: true, - TableIdents: $5.([]table.Ident), - Refs: $7.(*rsets.JoinRset), + Tables: $5.([]*ast.TableName), + TableRefs: &ast.TableRefsClause{TableRefs: $7.(*ast.Join)}, } if $8 != nil { - x.Where = $8.(expression.Expression) + x.Where = $8.(ast.ExprNode) } $$ = x if yylex.(*lexer).root { break } } -| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableIdentList "USING" TableRefs WhereClauseOptional +| "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableNameList "USING" TableRefs WhereClauseOptional { // Multiple Table - x := &stmts.DeleteStmt{ + x := &ast.DeleteStmt{ LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), MultiTable: true, - TableIdents: $6.([]table.Ident), - Refs: $8.(*rsets.JoinRset), + Tables: $6.([]*ast.TableName), + TableRefs: &ast.TableRefsClause{TableRefs: $8.(*ast.Join)}, } if $9 != nil { - x.Where = $9.(expression.Expression) + x.Where = $9.(ast.ExprNode) } $$ = x if yylex.(*lexer).root { @@ -1302,29 +1219,29 @@ DatabaseSym: DropDatabaseStmt: "DROP" DatabaseSym IfExists DBName { - $$ = &stmts.DropDatabaseStmt{IfExists: $3.(bool), Name: $4.(string)} + $$ = &ast.DropDatabaseStmt{IfExists: $3.(bool), Name: $4.(string)} if yylex.(*lexer).root { break } } DropIndexStmt: - "DROP" "INDEX" IfExists Identifier + "DROP" "INDEX" IfExists Identifier "ON" TableName { - $$ = &stmts.DropIndexStmt{IfExists: $3.(bool), IndexName: $4.(string)} + $$ = &ast.DropIndexStmt{IfExists: $3.(bool), IndexName: $4.(string), Table: $6.(*ast.TableName)} } DropTableStmt: - "DROP" TableOrTables TableIdentList + "DROP" TableOrTables TableNameList { - $$ = &stmts.DropTableStmt{TableIdents: $3.([]table.Ident)} + $$ = &ast.DropTableStmt{Tables: $3.([]*ast.TableName)} if yylex.(*lexer).root { break } } -| "DROP" "TABLE" "IF" "EXISTS" TableIdentList +| "DROP" TableOrTables "IF" "EXISTS" TableNameList { - $$ = &stmts.DropTableStmt{IfExists: true, TableIdents: $5.([]table.Ident)} + $$ = &ast.DropTableStmt{IfExists: true, Tables: $5.([]*ast.TableName)} if yylex.(*lexer).root { break } @@ -1353,26 +1270,28 @@ ExplainSym: | "DESC" ExplainStmt: - ExplainSym TableIdent + ExplainSym TableName { - $$ = &stmts.ExplainStmt{ - S:&stmts.ShowStmt{ - Target: stmt.ShowColumns, - TableIdent: $2.(table.Ident)}, + $$ = &ast.ExplainStmt{ + Stmt: &ast.ShowStmt{ + Tp: ast.ShowColumns, + Table: $2.(*ast.TableName), + }, } } -| ExplainSym TableIdent ColumnName +| ExplainSym TableName ColumnName { - $$ = &stmts.ExplainStmt{ - S:&stmts.ShowStmt{ - Target: stmt.ShowColumns, - TableIdent: $2.(table.Ident), - ColumnName: $3.(string)}, + $$ = &ast.ExplainStmt{ + Stmt: &ast.ShowStmt{ + Tp: ast.ShowColumns, + Table: $2.(*ast.TableName), + Column: $3.(*ast.ColumnName), + }, } } | ExplainSym ExplainableStmt { - $$ = &stmts.ExplainStmt{S:$2.(stmt.Statement)} + $$ = &ast.ExplainStmt{Stmt: $2.(ast.StmtNode)} } LengthNum: @@ -1392,32 +1311,32 @@ NUM: Expression: Expression logOr Expression %prec oror { - $$ = expression.NewBinaryOperation(opcode.OrOr, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.OrOr, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | Expression "XOR" Expression %prec xor { - $$ = expression.NewBinaryOperation(opcode.LogicXor, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.LogicXor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | Expression logAnd Expression %prec andand { - $$ = expression.NewBinaryOperation(opcode.AndAnd, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.AndAnd, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | "NOT" Expression %prec not { - $$ = expression.NewUnaryOperation(opcode.Not, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} } | Factor "IS" NotOpt trueKwd %prec is { - $$ = &expression.IsTruth{Expr:$1.(expression.Expression), Not: $3.(bool), True: int64(1)} + $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(1)} } | Factor "IS" NotOpt falseKwd %prec is { - $$ = &expression.IsTruth{Expr:$1.(expression.Expression), Not: $3.(bool), True: int64(0)} + $$ = &ast.IsTruthExpr{Expr:$1.(ast.ExprNode), Not: $3.(bool), True: int64(0)} } | Factor "IS" NotOpt "UNKNOWN" %prec is { /* https://dev.mysql.com/doc/refman/5.7/en/comparison-operators.html#operator_is */ - $$ = &expression.IsNull{Expr: $1.(expression.Expression), Not: $3.(bool)} + $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} } | Factor @@ -1444,31 +1363,31 @@ name: ExpressionList: Expression { - $$ = []expression.Expression{expression.Expr($1)} + $$ = []ast.ExprNode{$1.(ast.ExprNode)} } | ExpressionList ',' Expression { - $$ = append($1.([]expression.Expression), expression.Expr($3)) + $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) } ExpressionListOpt: { - $$ = []expression.Expression{} + $$ = []ast.ExprNode{} } | ExpressionList Factor: Factor "IS" NotOpt "NULL" %prec is { - $$ = &expression.IsNull{Expr: $1.(expression.Expression), Not: $3.(bool)} + $$ = &ast.IsNullExpr{Expr: $1.(ast.ExprNode), Not: $3.(bool)} } | Factor CompareOp PredicateExpr %prec eq { - $$ = expression.NewBinaryOperation($2.(opcode.Op), $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | Factor CompareOp AnyOrAll SubSelect %prec eq { - $$ = expression.NewCompareSubQuery($2.(opcode.Op), $1.(expression.Expression), $4.(*subquery.SubQuery), $3.(bool)) + $$ = &ast.CompareSubqueryExpr{Op: $2.(opcode.Op), L: $1.(ast.ExprNode), R: $4.(*ast.SubqueryExpr), All: $3.(bool)} } | PredicateExpr @@ -1523,19 +1442,19 @@ AnyOrAll: PredicateExpr: PrimaryFactor NotOpt "IN" '(' ExpressionList ')' { - $$ = &expression.PatternIn{Expr: $1.(expression.Expression), Not: $2.(bool), List: $5.([]expression.Expression)} + $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), List: $5.([]ast.ExprNode)} } | PrimaryFactor NotOpt "IN" SubSelect { - $$ = &expression.PatternIn{Expr: $1.(expression.Expression), Not: $2.(bool), Sel: $4.(*subquery.SubQuery)} + $$ = &ast.PatternInExpr{Expr: $1.(ast.ExprNode), Not: $2.(bool), Sel: $4.(*ast.SubqueryExpr)} } | PrimaryFactor NotOpt "BETWEEN" PrimaryFactor "AND" PredicateExpr { - var err error - $$, err = expression.NewBetween($1.(expression.Expression), $4.(expression.Expression), $6.(expression.Expression), $2.(bool)) - if err != nil { - yylex.(*lexer).err(err) - return 1 + $$ = &ast.BetweenExpr{ + Expr: $1.(ast.ExprNode), + Left: $4.(ast.ExprNode), + Right: $6.(ast.ExprNode), + Not: $2.(bool), } } | PrimaryFactor NotOpt "LIKE" PrimaryExpression LikeEscapeOpt @@ -1547,15 +1466,16 @@ PredicateExpr: } else if len(escape) == 0 { escape = "\\" } - $$ = &expression.PatternLike{ - Expr: $1.(expression.Expression), - Pattern: $4.(expression.Expression), - Not: $2.(bool), - Escape: escape[0]} + $$ = &ast.PatternLikeExpr{ + Expr: $1.(ast.ExprNode), + Pattern: $4.(ast.ExprNode), + Not: $2.(bool), + Escape: escape[0], + } } | PrimaryFactor NotOpt RegexpSym PrimaryExpression { - $$ = &expression.PatternRegexp{Expr: $1.(expression.Expression), Pattern: $4.(expression.Expression), Not: $2.(bool)} + $$ = &ast.PatternRegexpExpr{Expr: $1.(ast.ExprNode), Pattern: $4.(ast.ExprNode), Not: $2.(bool)} } | PrimaryFactor @@ -1585,12 +1505,29 @@ NotOpt: Field: '*' { - $$ = &field.Field{Expr: &expression.Ident{CIStr: model.NewCIStr("*")}} + $$ = &ast.SelectField{WildCard: &ast.WildCardField{}} + } +| Identifier '.' '*' + { + wildCard := &ast.WildCardField{Table: model.NewCIStr($1.(string))} + $$ = &ast.SelectField{WildCard: wildCard} + } +| Identifier '.' Identifier '.' '*' + { + wildCard := &ast.WildCardField{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Expression FieldAsNameOpt { - expr, name := expression.Expr($1), $2.(string) - $$ = &field.Field{Expr: expr, AsName: name} + expr := $1.(ast.ExprNode) + asName := $2.(string) + if asName != "" { + // Set expr original text. + offset := yyS[yypt-1].offset + end := yyS[yypt].offset-1 + expr.SetText(yylex.(*lexer).src[offset:end]) + } + $$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)} } FieldAsNameOpt: @@ -1624,27 +1561,28 @@ FieldAsName: FieldList: Field { - $$ = []*field.Field{$1.(*field.Field)} + field := $1.(*ast.SelectField) + field.Offset = yyS[yypt].offset + $$ = []*ast.SelectField{field} } | FieldList ',' Field { - $$ = append($1.([]*field.Field), $3.(*field.Field)) + + fl := $1.([]*ast.SelectField) + last := fl[len(fl)-1] + if last.Expr != nil && last.AsName.O == "" { + lastEnd := yyS[yypt-1].offset-1 // Comma offset. + last.SetText(yylex.(*lexer).src[last.Offset:lastEnd]) + } + newField := $3.(*ast.SelectField) + newField.Offset = yyS[yypt].offset + $$ = append(fl, newField) } GroupByClause: - "GROUP" "BY" GroupByList + "GROUP" "BY" ByList { - $$ = &rsets.GroupByRset{By: $3.([]expression.Expression)} - } - -GroupByList: - Expression - { - $$ = []expression.Expression{$1.(expression.Expression)} - } -| GroupByList ',' Expression - { - $$ = append($1.([]expression.Expression), $3.(expression.Expression)) + $$ = &ast.GroupByClause{Items: $3.([]*ast.ByItem)} } HavingClause: @@ -1653,7 +1591,7 @@ HavingClause: } | "HAVING" Expression { - $$ = &rsets.HavingRset{Expr:$2.(expression.Expression)} + $$ = &ast.HavingClause{Expr: $2.(ast.ExprNode)} } IfExists: @@ -1726,13 +1664,15 @@ NotKeywordToken: * TODO: support PARTITION **********************************************************************************/ InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableIdent InsertValues OnDuplicateKeyUpdate + "INSERT" Priority IgnoreOptional IntoOpt TableName InsertValues OnDuplicateKeyUpdate { - x := &stmts.InsertIntoStmt{InsertValues: $6.(stmts.InsertValues)} + x := $6.(*ast.InsertStmt) x.Priority = $2.(int) - x.TableIdent = $5.(table.Ident) + // Wraps many layers here so that it can be processed the same way as select statement. + ts := &ast.TableSource{Source: $5.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} if $7 != nil { - x.OnDuplicate = $7.([]*expression.Assignment) + x.OnDuplicate = $7.([]*ast.Assignment) } $$ = x if yylex.(*lexer).root { @@ -1750,33 +1690,34 @@ IntoOpt: InsertValues: '(' ColumnNameListOpt ')' ValueSym ExpressionListList { - $$ = stmts.InsertValues{ - ColNames: $2.([]string), - Lists: $5.([][]expression.Expression)} + $$ = &ast.InsertStmt{ + Columns: $2.([]*ast.ColumnName), + Lists: $5.([][]ast.ExprNode), + } } | '(' ColumnNameListOpt ')' SelectStmt { - $$ = stmts.InsertValues{ColNames: $2.([]string), Sel: $4.(*stmts.SelectStmt)} + $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.SelectStmt)} } | '(' ColumnNameListOpt ')' UnionStmt { - $$ = stmts.InsertValues{ColNames: $2.([]string), Sel: $4.(*stmts.UnionStmt)} + $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.UnionStmt)} } | ValueSym ExpressionListList %prec insertValues { - $$ = stmts.InsertValues{Lists: $2.([][]expression.Expression)} + $$ = &ast.InsertStmt{Lists: $2.([][]ast.ExprNode)} } | SelectStmt { - $$ = stmts.InsertValues{Sel: $1.(*stmts.SelectStmt)} + $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} } | UnionStmt { - $$ = stmts.InsertValues{Sel: $1.(*stmts.UnionStmt)} + $$ = &ast.InsertStmt{Select: $1.(*ast.UnionStmt)} } | "SET" ColumnSetValueList { - $$ = stmts.InsertValues{Setlist: $2.([]*expression.Assignment)} + $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} } ValueSym: @@ -1786,40 +1727,41 @@ ValueSym: ExpressionListList: '(' ')' { - $$ = [][]expression.Expression{[]expression.Expression{}} + $$ = [][]ast.ExprNode{[]ast.ExprNode{}} } | '(' ')' ',' ExpressionListList { - $$ = append([][]expression.Expression{[]expression.Expression{}}, $4.([][]expression.Expression)...) + $$ = append([][]ast.ExprNode{[]ast.ExprNode{}}, $4.([][]ast.ExprNode)...) } | '(' ExpressionList ')' { - $$ = [][]expression.Expression{$2.([]expression.Expression)} + $$ = [][]ast.ExprNode{$2.([]ast.ExprNode)} } | '(' ExpressionList ')' ',' ExpressionListList { - $$ = append([][]expression.Expression{$2.([]expression.Expression)}, $5.([][]expression.Expression)...) + $$ = append([][]ast.ExprNode{$2.([]ast.ExprNode)}, $5.([][]ast.ExprNode)...) } ColumnSetValue: ColumnName eq Expression { - $$ = &expression.Assignment{ - ColName: $1.(string), - Expr: expression.Expr($3)} + $$ = &ast.Assignment{ + Column: $1.(*ast.ColumnName), + Expr: $3.(ast.ExprNode), + } } ColumnSetValueList: { - $$ = []*expression.Assignment{} + $$ = []*ast.Assignment{} } | ColumnSetValue { - $$ = []*expression.Assignment{$1.(*expression.Assignment)} + $$ = []*ast.Assignment{$1.(*ast.Assignment)} } | ColumnSetValueList ',' ColumnSetValue { - $$ = append($1.([]*expression.Assignment), $3.(*expression.Assignment)) + $$ = append($1.([]*ast.Assignment), $3.(*ast.Assignment)) } /* @@ -1835,34 +1777,40 @@ OnDuplicateKeyUpdate: $$ = $5 } +/***********************************Insert Statements END************************************/ + /************************************************************************************ - * Replace Statments + * Replace Statements * See: https://dev.mysql.com/doc/refman/5.7/en/replace.html * * TODO: support PARTITION **********************************************************************************/ ReplaceIntoStmt: - "REPLACE" ReplacePriority IntoOpt TableIdent InsertValues + "REPLACE" ReplacePriority IntoOpt TableName InsertValues { - x := &stmts.ReplaceIntoStmt{InsertValues: $5.(stmts.InsertValues)} + x := $5.(*ast.InsertStmt) + x.Replace = true x.Priority = $2.(int) - x.TableIdent = $4.(table.Ident) + ts := &ast.TableSource{Source: $4.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} $$ = x } ReplacePriority: { - $$ = stmts.NoPriority + $$ = ast.NoPriority } | "LOW_PRIORITY" { - $$ = stmts.LowPriority + $$ = ast.LowPriority } | "DELAYED" { - $$ = stmts.DelayedPriority + $$ = ast.DelayedPriority } +/***********************************Replace Statments END************************************/ + Literal: "false" { @@ -1908,23 +1856,23 @@ Literal: Operand: Literal { - $$ = expression.Value{Val: $1} + $$ = ast.NewValueExpr($1) } -| QualifiedIdent +| ColumnName { - $$ = &expression.Ident{CIStr: model.NewCIStr($1.(string))} + $$ = &ast.ColumnNameExpr{Name: $1.(*ast.ColumnName)} } | '(' Expression ')' { - $$ = &expression.PExpr{Expr: expression.Expr($2)} + $$ = &ast.ParenthesesExpr{Expr: $2.(ast.ExprNode)} } | "DEFAULT" %prec lowerThanLeftParen { - $$ = &expression.Default{} + $$ = &ast.DefaultExpr{} } | "DEFAULT" '(' ColumnName ')' { - $$ = &expression.Default{Name: $3.(string)} + $$ = &ast.DefaultExpr{Name: $3.(*ast.ColumnName)} } | Variable { @@ -1932,64 +1880,59 @@ Operand: } | "PLACEHOLDER" { - l := yylex.(*lexer) - if !l.prepare { - l.err("Can not accept placeholder when not parsing prepare sql") - return 1 + $$ = &ast.ParamMarkerExpr{ + Offset: yyS[yypt].offset, } - pm := &expression.ParamMarker{} - l.ParamList = append(l.ParamList, pm) - $$ = pm } | "ROW" '(' Expression ',' ExpressionList ')' { - values := append([]expression.Expression{$3.(expression.Expression)}, $5.([]expression.Expression)...) - $$ = &expression.Row{Values: values} + values := append([]ast.ExprNode{$3.(ast.ExprNode)}, $5.([]ast.ExprNode)...) + $$ = &ast.RowExpr{Values: values} } | '(' Expression ',' ExpressionList ')' { - values := append([]expression.Expression{$2.(expression.Expression)}, $4.([]expression.Expression)...) - $$ = &expression.Row{Values: values} + values := append([]ast.ExprNode{$2.(ast.ExprNode)}, $4.([]ast.ExprNode)...) + $$ = &ast.RowExpr{Values: values} } | "EXISTS" SubSelect { - $$ = &expression.ExistsSubQuery{Sel: $2.(*subquery.SubQuery)} + $$ = &ast.ExistsSubqueryExpr{Sel: $2.(*ast.SubqueryExpr)} } OrderBy: - "ORDER" "BY" OrderByList + "ORDER" "BY" ByList { - $$ = &rsets.OrderByRset{By: $3.([]rsets.OrderByItem)} + $$ = &ast.OrderByClause{Items: $3.([]*ast.ByItem)} } -OrderByList: - OrderByItem +ByList: + ByItem { - $$ = []rsets.OrderByItem{$1.(rsets.OrderByItem)} + $$ = []*ast.ByItem{$1.(*ast.ByItem)} } -| OrderByList ',' OrderByItem +| ByList ',' ByItem { - $$ = append($1.([]rsets.OrderByItem), $3.(rsets.OrderByItem)) + $$ = append($1.([]*ast.ByItem), $3.(*ast.ByItem)) } -OrderByItem: +ByItem: Expression Order { - $$ = rsets.OrderByItem{Expr: $1.(expression.Expression), Asc: $2.(bool)} + $$ = &ast.ByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} } Order: /* EMPTY */ { - $$ = true // ASC by default + $$ = false // ASC by default } | "ASC" { - $$ = true + $$ = false } | "DESC" { - $$ = false + $$ = true } OrderByOptional: @@ -2007,19 +1950,19 @@ PrimaryExpression: | SubSelect | '!' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.Not, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Not, V: $2.(ast.ExprNode)} } | '~' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.BitNeg, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.BitNeg, V: $2.(ast.ExprNode)} } | '-' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.Minus, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: $2.(ast.ExprNode)} } | '+' PrimaryExpression %prec neg { - $$ = expression.NewUnaryOperation(opcode.Plus, $2.(expression.Expression)) + $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: $2.(ast.ExprNode)} } | "BINARY" PrimaryExpression %prec neg { @@ -2027,13 +1970,13 @@ PrimaryExpression: x := types.NewFieldType(mysql.TypeString) x.Charset = charset.CharsetBin x.Collate = charset.CharsetBin - $$ = &expression.FunctionCast{ - Expr: $2.(expression.Expression), + $$ = &ast.FuncCastExpr{ + Expr: $2.(ast.ExprNode), Tp: x, - FunctionType: expression.BinaryOperator, + FunctionType: ast.CastBinaryOperator, } } -| PrimaryExpression "COLLATE" CollationName %prec neg +| PrimaryExpression "COLLATE" StringName %prec neg { // TODO: Create a builtin function hold expr and collation. When do evaluation, convert expr result using the collation. $$ = $1 @@ -2051,34 +1994,16 @@ FunctionNameConflict: FunctionCallConflict: FunctionNameConflict '(' ExpressionListOpt ')' { - x := yylex.(*lexer) - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - x.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURRENT_USER" { // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user - x := yylex.(*lexer) - var err error - $$, err = expression.NewCall($1.(string), []expression.Expression{}, false) - if err != nil { - x.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_DATE" { - x := yylex.(*lexer) - var err error - $$, err = expression.NewCall($1.(string), []expression.Expression{}, false) - if err != nil { - x.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } DistinctOpt: @@ -2099,604 +2024,346 @@ DistinctOpt: } FunctionCallKeyword: - "AVG" '(' DistinctOpt ExpressionList ')' - { - var err error - $$, err = expression.NewCall($1.(string), $4.([]expression.Expression), $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } - } -| "CAST" '(' Expression "AS" CastType ')' + "CAST" '(' Expression "AS" CastType ')' { /* See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_cast */ - $$ = &expression.FunctionCast{ - Expr: $3.(expression.Expression), + $$ = &ast.FuncCastExpr{ + Expr: $3.(ast.ExprNode), Tp: $5.(*types.FieldType), - FunctionType: expression.CastFunction, - } + FunctionType: ast.CastFunction, + } } | "CASE" ExpressionOpt WhenClauseList ElseOpt "END" { - x := &expression.FunctionCase{WhenClauses: $3.([]*expression.WhenClause)} + x := &ast.CaseExpr{WhenClauses: $3.([]*ast.WhenClause)} if $2 != nil { - x.Value = $2.(expression.Expression) + x.Value = $2.(ast.ExprNode) } if $4 != nil { - x.ElseClause = $4.(expression.Expression) + x.ElseClause = $4.(ast.ExprNode) } $$ = x } -| "CONVERT" '(' Expression "USING" CharsetName ')' +| "CONVERT" '(' Expression "USING" StringName ')' { // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &expression.FunctionConvert{ - Expr: $3.(expression.Expression), + $$ = &ast.FuncConvertExpr{ + Expr: $3.(ast.ExprNode), Charset: $5.(string), } } -| "CONVERT" '(' Expression ',' CastType ')' +| "CONVERT" '(' Expression ',' CastType ')' { // See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_convert - $$ = &expression.FunctionCast{ - Expr: $3.(expression.Expression), + $$ = &ast.FuncCastExpr{ + Expr: $3.(ast.ExprNode), Tp: $5.(*types.FieldType), - FunctionType: expression.ConvertFunction, + FunctionType: ast.CastConvertFunction, } } | "DATE" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "USER" '(' ')' { - args := []expression.Expression{} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } -| "VALUES" '(' Identifier ')' %prec lowerThanInsertValues +| "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues { // TODO: support qualified identifier for column_name - $$ = &expression.Values{CIStr: model.NewCIStr($3.(string))} + $$ = &ast.ValuesExpr{Column: $3.(*ast.ColumnName)} } | "WEEK" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "YEAR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } FunctionCallNonKeyword: "COALESCE" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURDATE" '(' ')' { - var err error - $$, err = expression.NewCall($1.(string), []expression.Expression{}, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_TIMESTAMP" FuncDatetimePrec { - args := []expression.Expression{} + args := []ast.ExprNode{} if $2 != nil { - args = append(args, $2.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $2.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "ABS" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "CONCAT" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CONCAT_WS" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "DAY" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFWEEK" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFMONTH" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFYEAR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { - $$ = &expression.DateArith{ - Op:expression.DateAdd, - Unit: $7.(string), - Date: $3.(expression.Expression), - Interval: $6.(expression.Expression), + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateAdd, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), } } | "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { - $$ = &expression.DateArith{ - Op:expression.DateSub, + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateSub, Unit: $7.(string), - Date: $3.(expression.Expression), - Interval: $6.(expression.Expression), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), } } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { - $$ = &expression.Extract{ - Unit: $3.(string), - Date: $5.(expression.Expression), + $$ = &ast.FuncExtractExpr{ + Unit: $3.(string), + Date: $5.(ast.ExprNode), } } | "FOUND_ROWS" '(' ')' { - args := []expression.Expression{} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "HOUR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "IFNULL" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "LENGTH" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "LOCATE" '(' Expression ',' Expression ')' { - $$ = &expression.FunctionLocate{ - SubStr: $3.(expression.Expression), - Str: $5.(expression.Expression), + $$ = &ast.FuncLocateExpr{ + SubStr: $3.(ast.ExprNode), + Str: $5.(ast.ExprNode), } } | "LOCATE" '(' Expression ',' Expression ',' Expression ')' { - $$ = &expression.FunctionLocate{ - SubStr: $3.(expression.Expression), - Str: $5.(expression.Expression), - Pos: $7.(expression.Expression), + $$ = &ast.FuncLocateExpr{ + SubStr: $3.(ast.ExprNode), + Str: $5.(ast.ExprNode), + Pos: $7.(ast.ExprNode), } } | "LOWER" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MICROSECOND" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MINUTE" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MONTH" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "NOW" '(' ExpressionOpt ')' { - args := []expression.Expression{} + args := []ast.ExprNode{} if $3 != nil { - args = append(args, $3.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $3.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "NULLIF" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression), false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "RAND" '(' ExpressionOpt ')' { - args := []expression.Expression{} + args := []ast.ExprNode{} if $3 != nil { - args = append(args, $3.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $3.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "REPLACE" '(' Expression ',' Expression ',' Expression ')' { - args := []expression.Expression{$3.(expression.Expression), - $5.(expression.Expression), - $7.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "SECOND" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "SUBSTRING" '(' Expression ',' Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), } } | "SUBSTRING" '(' Expression "FROM" Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), } } | "SUBSTRING" '(' Expression ',' Expression ',' Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), - Len: $7.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), + Len: $7.(ast.ExprNode), } } | "SUBSTRING" '(' Expression "FROM" Expression "FOR" Expression ')' { - $$ = &expression.FunctionSubstring{ - StrExpr: $3.(expression.Expression), - Pos: $5.(expression.Expression), - Len: $7.(expression.Expression), + $$ = &ast.FuncSubstringExpr{ + StrExpr: $3.(ast.ExprNode), + Pos: $5.(ast.ExprNode), + Len: $7.(ast.ExprNode), } } | "SUBSTRING_INDEX" '(' Expression ',' Expression ',' Expression ')' { - $$ = &expression.FunctionSubstringIndex{ - StrExpr: $3.(expression.Expression), - Delim: $5.(expression.Expression), - Count: $7.(expression.Expression), + $$ = &ast.FuncSubstringIndexExpr{ + StrExpr: $3.(ast.ExprNode), + Delim: $5.(ast.ExprNode), + Count: $7.(ast.ExprNode), } } | "SYSDATE" '(' ExpressionOpt ')' { - args := []expression.Expression{} + args := []ast.ExprNode{} if $3 != nil { - args = append(args, $3.(expression.Expression)) - } - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 + args = append(args, $3.(ast.ExprNode)) } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "TRIM" '(' Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $3.(expression.Expression), + $$ = &ast.FuncTrimExpr{ + Str: $3.(ast.ExprNode), } } | "TRIM" '(' Expression "FROM" Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $5.(expression.Expression), - RemStr: $3.(expression.Expression), + $$ = &ast.FuncTrimExpr{ + Str: $5.(ast.ExprNode), + RemStr: $3.(ast.ExprNode), } } | "TRIM" '(' TrimDirection "FROM" Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $5.(expression.Expression), - Direction: $3.(int), + $$ = &ast.FuncTrimExpr{ + Str: $5.(ast.ExprNode), + Direction: $3.(ast.TrimDirectionType), } } | "TRIM" '(' TrimDirection Expression "FROM" Expression ')' { - $$ = &expression.FunctionTrim{ - Str: $6.(expression.Expression), - RemStr: $4.(expression.Expression), - Direction: $3.(int), + $$ = &ast.FuncTrimExpr{ + Str: $6.(ast.ExprNode), + RemStr: $4.(ast.ExprNode), + Direction: $3.(ast.TrimDirectionType), } } | "UPPER" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKDAY" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKOFYEAR" '(' Expression ')' { - args := []expression.Expression{$3.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "YEARWEEK" '(' ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $3.([]expression.Expression),false) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } TrimDirection: "BOTH" { - $$ = expression.TrimBoth + $$ = ast.TrimBoth } | "LEADING" { - $$ = expression.TrimLeading + $$ = ast.TrimLeading } | "TRAILING" { - $$ = expression.TrimTrailing + $$ = ast.TrimTrailing } FunctionCallAgg: - "COUNT" '(' DistinctOpt ExpressionList ')' + "AVG" '(' DistinctOpt ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $4.([]expression.Expression), $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + } +| "COUNT" '(' DistinctOpt ExpressionList ')' + { + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt '*' ')' { - var err error - args := []expression.Expression{ expression.Value{Val: expression.TypeStar("*")} } - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + args := []ast.ExprNode{ast.NewValueExpr(ast.UnquoteString("*"))} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' { - var err error - $$, err = expression.NewCall($1.(string), $4.([]expression.Expression),$3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "MAX" '(' DistinctOpt Expression ')' { - args := []expression.Expression{$4.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "MIN" '(' DistinctOpt Expression ')' { - args := []expression.Expression{$4.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "SUM" '(' DistinctOpt Expression ')' { - args := []expression.Expression{$4.(expression.Expression)} - var err error - $$, err = expression.NewCall($1.(string), args, $3.(bool)) - if err != nil { - l := yylex.(*lexer) - l.err(err) - return 1 - } + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } FuncDatetimePrec: @@ -2730,19 +2397,19 @@ ExpressionOpt: WhenClauseList: WhenClause { - $$ = []*expression.WhenClause{$1.(*expression.WhenClause)} + $$ = []*ast.WhenClause{$1.(*ast.WhenClause)} } | WhenClauseList WhenClause { - $$ = append($1.([]*expression.WhenClause), $2.(*expression.WhenClause)) + $$ = append($1.([]*ast.WhenClause), $2.(*ast.WhenClause)) } WhenClause: "WHEN" Expression "THEN" Expression { - $$ = &expression.WhenClause{ - Expr: $2.(expression.Expression), - Result: $4.(expression.Expression), + $$ = &ast.WhenClause{ + Expr: $2.(ast.ExprNode), + Result: $4.(ast.ExprNode), } } @@ -2788,7 +2455,7 @@ CastType: } | "DECIMAL" FloatOpt { - fopt := $2.(*coldef.FloatOpt) + fopt := $2.(*ast.FloatOpt) x := types.NewFieldType(mysql.TypeNewDecimal) x.Flen = fopt.Flen x.Decimal = fopt.Decimal @@ -2816,69 +2483,70 @@ CastType: PrimaryFactor: PrimaryFactor '|' PrimaryFactor %prec '|' { - $$ = expression.NewBinaryOperation(opcode.Or, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Or, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '&' PrimaryFactor %prec '&' { - $$ = expression.NewBinaryOperation(opcode.And, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.And, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor "<<" PrimaryFactor %prec lsh { - $$ = expression.NewBinaryOperation(opcode.LeftShift, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.LeftShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor ">>" PrimaryFactor %prec rsh { - $$ = expression.NewBinaryOperation(opcode.RightShift, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.RightShift, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '+' PrimaryFactor %prec '+' { - $$ = expression.NewBinaryOperation(opcode.Plus, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Plus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '-' PrimaryFactor %prec '-' { - $$ = expression.NewBinaryOperation(opcode.Minus, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Minus, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '*' PrimaryFactor %prec '*' { - $$ = expression.NewBinaryOperation(opcode.Mul, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Mul, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '/' PrimaryFactor %prec '/' { - $$ = expression.NewBinaryOperation(opcode.Div, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Div, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '%' PrimaryFactor %prec '%' { - $$ = expression.NewBinaryOperation(opcode.Mod, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor "DIV" PrimaryFactor %prec div { - $$ = expression.NewBinaryOperation(opcode.IntDiv, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.IntDiv, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor "MOD" PrimaryFactor %prec mod { - $$ = expression.NewBinaryOperation(opcode.Mod, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Mod, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryFactor '^' PrimaryFactor { - $$ = expression.NewBinaryOperation(opcode.Xor, $1.(expression.Expression), $3.(expression.Expression)) + $$ = &ast.BinaryOperationExpr{Op: opcode.Xor, L: $1.(ast.ExprNode), R: $3.(ast.ExprNode)} } | PrimaryExpression + Priority: { - $$ = stmts.NoPriority + $$ = ast.NoPriority } | "LOW_PRIORITY" { - $$ = stmts.LowPriority + $$ = ast.LowPriority } | "HIGH_PRIORITY" { - $$ = stmts.HighPriority + $$ = ast.HighPriority } | "DELAYED" { - $$ = stmts.DelayedPriority + $$ = ast.DelayedPriority } LowPriorityOptional: @@ -2890,55 +2558,25 @@ LowPriorityOptional: $$ = true } -QualifiedIdent: +TableName: Identifier -| Identifier '.' '*' { - $$ = fmt.Sprintf("%s.*", $1.(string)) + $$ = &ast.TableName{Name:model.NewCIStr($1.(string))} } | Identifier '.' Identifier { - $$ = fmt.Sprintf("%s.%s", $1.(string), $3.(string)) - } -| Identifier '.' Identifier '.' Identifier - { - $$ = fmt.Sprintf("%s.%s.%s", $1.(string), $3.(string), $5.(string)) - } -| Identifier '.' Identifier '.' '*' - { - $$ = fmt.Sprintf("%s.%s.*", $1.(string), $3.(string)) + $$ = &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} } -SimpleQualifiedIdent: - Identifier -| Identifier '.' Identifier +TableNameList: + TableName { - $$ = fmt.Sprintf("%s.%s", $1.(string), $3.(string)) - } -| Identifier '.' Identifier '.' Identifier - { - $$ = fmt.Sprintf("%s.%s.%s", $1.(string), $3.(string), $5.(string)) - } - -TableIdent: - Identifier - { - $$ = table.Ident{Name:model.NewCIStr($1.(string))} - } -| Identifier '.' Identifier - { - $$ = table.Ident{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - } - -TableIdentList: - TableIdent - { - tbl := []table.Ident{$1.(table.Ident)} + tbl := []*ast.TableName{$1.(*ast.TableName)} $$ = tbl } -| TableIdentList ',' TableIdent +| TableNameList ',' TableName { - $$ = append($1.([]table.Ident), $3.(table.Ident)) + $$ = append($1.([]*ast.TableName), $3.(*ast.TableName)) } QuickOptional: @@ -2951,7 +2589,6 @@ QuickOptional: $$ = true } - /***************************Prepared Statement Start****************************** * See: https://dev.mysql.com/doc/refman/5.7/en/prepare.html * Example: @@ -2965,14 +2602,14 @@ PreparedStmt: "PREPARE" Identifier "FROM" PrepareSQL { var sqlText string - var sqlVar *expression.Variable + var sqlVar *ast.VariableExpr switch $4.(type) { case string: sqlText = $4.(string) - case *expression.Variable: - sqlVar = $4.(*expression.Variable) + case *ast.VariableExpr: + sqlVar = $4.(*ast.VariableExpr) } - $$ = &stmts.PreparedStmt{ + $$ = &ast.PrepareStmt{ InPrepare: true, Name: $2.(string), SQLText: sqlText, @@ -2995,24 +2632,24 @@ PrepareSQL: ExecuteStmt: "EXECUTE" Identifier { - $$ = &stmts.ExecuteStmt{Name: $2.(string)} + $$ = &ast.ExecuteStmt{Name: $2.(string)} } | "EXECUTE" Identifier "USING" UserVariableList { - $$ = &stmts.ExecuteStmt{ + $$ = &ast.ExecuteStmt{ Name: $2.(string), - UsingVars: $4.([]expression.Expression), + UsingVars: $4.([]ast.ExprNode), } } UserVariableList: UserVariable { - $$ = []expression.Expression{$1.(expression.Expression)} + $$ = []ast.ExprNode{$1.(ast.ExprNode)} } | UserVariableList ',' UserVariable { - $$ = append($1.([]expression.Expression), $3.(expression.Expression)) + $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) } /* @@ -3022,7 +2659,7 @@ UserVariableList: DeallocateStmt: DeallocateSym "PREPARE" Identifier { - $$ = &stmts.DeallocateStmt{Name: $3.(string)} + $$ = &ast.DeallocateStmt{Name: $3.(string)} } DeallocateSym: @@ -3034,93 +2671,122 @@ DeallocateSym: RollbackStmt: "ROLLBACK" { - $$ = &stmts.RollbackStmt{} + $$ = &ast.RollbackStmt{} } SelectStmt: "SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt { - $$ = &stmts.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), - Fields: $3.([]*field.Field), - Lock: $5.(coldef.LockType), + Fields: $3.(*ast.FieldList), + LockTp: $5.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + src := yylex.(*lexer).src + var lastEnd int + if $4 != nil { + lastEnd = yyS[yypt-1].offset-1 + } else if $5 != ast.SelectLockNone { + lastEnd = yyS[yypt].offset-1 + } else { + lastEnd = len(src) + if src[lastEnd-1] == ';' { + lastEnd-- + } + } + lastField.SetText(src[lastField.Offset:lastEnd]) + } + if $4 != nil { + st.Limit = $4.(*ast.Limit) + } + $$ = st } | "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt { - st := &stmts.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), - Fields: $3.([]*field.Field), - From: nil, - Lock: $7.(coldef.LockType), + Fields: $3.(*ast.FieldList), + LockTp: $7.(ast.SelectLockType), + } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-3].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) } - if $5 != nil { - st.Where = &rsets.WhereRset{Expr: $5.(expression.Expression)} + st.Where = $5.(ast.ExprNode) + } + if $6 != nil { + st.Limit = $6.(*ast.Limit) } - $$ = st } -| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - FromClause WhereClauseOptional SelectStmtGroup HavingClause SelectStmtOrder +| "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" + TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional SelectStmtLimit SelectLockOpt { - st := &stmts.SelectStmt{ + st := &ast.SelectStmt{ Distinct: $2.(bool), - Fields: $3.([]*field.Field), - From: $5.(*rsets.JoinRset), - Lock: $11.(coldef.LockType), + Fields: $3.(*ast.FieldList), + From: $5.(*ast.TableRefsClause), + LockTp: $11.(ast.SelectLockType), + } + + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-7].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) } if $6 != nil { - st.Where = &rsets.WhereRset{Expr: $6.(expression.Expression)} + st.Where = $6.(ast.ExprNode) } if $7 != nil { - st.GroupBy = $7.(*rsets.GroupByRset) + st.GroupBy = $7.(*ast.GroupByClause) } if $8 != nil { - st.Having = $8.(*rsets.HavingRset) + st.Having = $8.(*ast.HavingClause) } if $9 != nil { - st.OrderBy = $9.(*rsets.OrderByRset) + st.OrderBy = $9.(*ast.OrderByClause) } if $10 != nil { - ay := $10.([]interface{}) - st.Limit = ay[0].(*rsets.LimitRset) - st.Offset = ay[1].(*rsets.OffsetRset) + st.Limit = $10.(*ast.Limit) } - + $$ = st } FromDual: - "FROM" "DUAL" + "FROM" "DUAL" -FromClause: +TableRefsClause: TableRefs { - $$ = $1 + $$ = &ast.TableRefsClause{TableRefs: $1.(*ast.Join)} } TableRefs: EscapedTableRef { - r := &rsets.JoinRset{Left: $1, Right: nil} - if j, ok := $1.(*rsets.JoinRset); ok { - // if $1 is JoinRset, use it directly - r = j - } - $$ = r + if j, ok := $1.(*ast.Join); ok { + // if $1 is Join, use it directly + $$ = j + } else { + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: nil} + } } | TableRefs ',' EscapedTableRef { /* from a, b is default cross join */ - $$ = &rsets.JoinRset{Left: $1, Right: $3, Type: rsets.CrossJoin} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} } EscapedTableRef: @@ -3148,66 +2814,70 @@ TableRef: } TableFactor: - TableIdent TableIdentOpt + TableName TableAsNameOpt { - $$ = &rsets.TableSource{Source: $1, Name: $2.(string)} + $$ = &ast.TableSource{Source: $1.(*ast.TableName), AsName: $2.(model.CIStr)} } -| '(' SelectStmt ')' TableAsOpt +| '(' SelectStmt ')' TableAsName { - $$ = &rsets.TableSource{Source: $2, Name: $4.(string)} + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1) + $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } -| '(' UnionStmt ')' TableAsOpt +| '(' UnionStmt ')' TableAsName { - $$ = &rsets.TableSource{Source: $2, Name: $4.(string)} + $$ = &ast.TableSource{Source: $2.(*ast.UnionStmt), AsName: $4.(model.CIStr)} } | '(' TableRefs ')' { $$ = $2 } -TableIdentOpt: +TableAsNameOpt: { - $$ = "" + $$ = model.CIStr{} } -| TableAsOpt +| TableAsName { $$ = $1 } -TableAsOpt: +TableAsName: Identifier { - $$ = $1 + $$ = model.NewCIStr($1.(string)) } | "AS" Identifier { - $$ = $2 + $$ = model.NewCIStr($2.(string)) } JoinTable: /* Use %prec to evaluate production TableRef before cross join */ TableRef CrossOpt TableRef %prec tableRefPriority { - $$ = &rsets.JoinRset{Left: $1, Right: $3, Type: rsets.CrossJoin} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin} } | TableRef CrossOpt TableRef "ON" Expression { - $$ = &rsets.JoinRset{Left: $1, Right: $3, Type: rsets.CrossJoin, On: $5.(expression.Expression)} + on := &ast.OnCondition{Expr: $5.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on} } | TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression - { - $$ = &rsets.JoinRset{Left: $1, Right: $5, Type: $2.(rsets.JoinType), On: $7.(expression.Expression)} - } + { + on := &ast.OnCondition{Expr: $7.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on} + } /* Support Using */ JoinType: "LEFT" { - $$ = rsets.LeftJoin + $$ = ast.LeftJoin } | "RIGHT" { - $$ = rsets.RightJoin + $$ = ast.RightJoin } OuterOpt: @@ -3225,11 +2895,11 @@ CrossOpt: LimitClause: { - $$ = (*rsets.LimitRset)(nil) + $$ = nil } | "LIMIT" LengthNum { - $$ = &rsets.LimitRset{Count: $2.(uint64)} + $$ = &ast.Limit{Count: $2.(uint64)} } SelectStmtLimit: @@ -3238,21 +2908,15 @@ SelectStmtLimit: } | "LIMIT" LengthNum { - $$ = []interface{}{ - &rsets.LimitRset{Count: $2.(uint64)}, - (*rsets.OffsetRset)(nil)} + $$ = &ast.Limit{Count: $2.(uint64)} } | "LIMIT" LengthNum ',' LengthNum { - $$ = []interface{}{ - &rsets.LimitRset{Count: $4.(uint64)}, - &rsets.OffsetRset{Count: $2.(uint64)}} + $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} } | "LIMIT" LengthNum "OFFSET" LengthNum { - $$ = []interface{}{ - &rsets.LimitRset{Count: $2.(uint64)}, - &rsets.OffsetRset{Count: $4.(uint64)}} + $$ = &ast.Limit{Offset: $4.(uint64), Count: $2.(uint64)} } SelectStmtDistinct: @@ -3289,7 +2953,7 @@ SelectStmtCalcFoundRows: SelectStmtFieldList: FieldList { - $$ = $1 + $$ = &ast.FieldList{Fields: $1.([]*ast.SelectField)} } SelectStmtGroup: @@ -3299,94 +2963,157 @@ SelectStmtGroup: } | GroupByClause -SelectStmtOrder: - /* EMPTY */ - { - $$ = nil - } -| OrderBy - - // See: https://dev.mysql.com/doc/refman/5.7/en/subqueries.html SubSelect: '(' SelectStmt ')' { - s := $2.(*stmts.SelectStmt) + s := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) - $$ = &subquery.SubQuery{Stmt: s} + $$ = &ast.SubqueryExpr{Query: s} } | '(' UnionStmt ')' { - s := $2.(*stmts.UnionStmt) + s := $2.(*ast.UnionStmt) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) - $$ = &subquery.SubQuery{Stmt: s} + $$ = &ast.SubqueryExpr{Query: s} } // See: https://dev.mysql.com/doc/refman/5.7/en/innodb-locking-reads.html SelectLockOpt: /* empty */ { - $$ = coldef.SelectLockNone + $$ = ast.SelectLockNone } | "FOR" "UPDATE" { - $$ = coldef.SelectLockForUpdate + $$ = ast.SelectLockForUpdate } | "LOCK" "IN" "SHARE" "MODE" { - $$ = coldef.SelectLockInShareMode + $$ = ast.SelectLockInShareMode } +// See: https://dev.mysql.com/doc/refman/5.7/en/union.html +UnionStmt: + UnionClauseList "UNION" UnionOpt SelectStmt + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union + } +| UnionClauseList "UNION" UnionOpt '(' SelectStmt ')' OrderByOptional SelectStmtLimit + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1) + st := $5.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, st) + if $7 != nil { + union.OrderBy = $7.(*ast.OrderByClause) + } + if $8 != nil { + union.Limit = $8.(*ast.Limit) + } + $$ = union + } + +UnionClauseList: + UnionSelect + { + selects := []*ast.SelectStmt{$1.(*ast.SelectStmt)} + $$ = &ast.UnionStmt{ + Selects: selects, + } + } +| UnionClauseList "UNION" UnionOpt UnionSelect + { + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union + } + +UnionSelect: + SelectStmt +| '(' SelectStmt ')' + { + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1) + $$ = st + } + +UnionOpt: + { + $$ = true + } +| "ALL" + { + $$ = false + } +| "DISTINCT" + { + $$ = true + } + + /********************Set Statement*******************************/ SetStmt: "SET" VariableAssignmentList { - $$ = &stmts.SetStmt{Variables: $2.([]*stmts.VariableAssignment)} + $$ = &ast.SetStmt{Variables: $2.([]*ast.VariableAssignment)} } -| "SET" "NAMES" CharsetName +| "SET" "NAMES" StringName { - $$ = &stmts.SetCharsetStmt{Charset: $3.(string)} + $$ = &ast.SetCharsetStmt{Charset: $3.(string)} } -| "SET" "NAMES" CharsetName "COLLATE" CollationName +| "SET" "NAMES" StringName "COLLATE" StringName { - $$ = &stmts.SetCharsetStmt{ + $$ = &ast.SetCharsetStmt{ Charset: $3.(string), Collate: $5.(string), } } -| "SET" CharsetKw CharsetName +| "SET" CharsetKw StringName { - $$ = &stmts.SetCharsetStmt{Charset: $3.(string)} + $$ = &ast.SetCharsetStmt{Charset: $3.(string)} } | "SET" "PASSWORD" eq PasswordOpt { - $$ = &stmts.SetPwdStmt{Password: $4.(string)} + $$ = &ast.SetPwdStmt{Password: $4.(string)} } | "SET" "PASSWORD" "FOR" Username eq PasswordOpt { - $$ = &stmts.SetPwdStmt{User: $4.(string), Password: $6.(string)} + $$ = &ast.SetPwdStmt{User: $4.(string), Password: $6.(string)} } VariableAssignment: Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $1.(string), Value: $3.(expression.Expression), IsSystem: true} + $$ = &ast.VariableAssignment{Name: $1.(string), Value: $3.(ast.ExprNode), IsSystem: true} } | "GLOBAL" Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $2.(string), Value: $4.(expression.Expression), IsGlobal: true, IsSystem: true} + $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsGlobal: true, IsSystem: true} } | "SESSION" Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $2.(string), Value: $4.(expression.Expression), IsSystem: true} + $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} } | "LOCAL" Identifier eq Expression { - $$ = &stmts.VariableAssignment{Name: $2.(string), Value: $4.(expression.Expression), IsSystem: true} + $$ = &ast.VariableAssignment{Name: $2.(string), Value: $4.(ast.ExprNode), IsSystem: true} } | "SYS_VAR" eq Expression { @@ -3402,26 +3129,26 @@ VariableAssignment: } else if strings.HasPrefix(v, "@@") { v = strings.TrimPrefix(v, "@@") } - $$ = &stmts.VariableAssignment{Name: v, Value: $3.(expression.Expression), IsGlobal: isGlobal, IsSystem: true} + $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode), IsGlobal: isGlobal, IsSystem: true} } | "USER_VAR" eq Expression { v := $1.(string) v = strings.TrimPrefix(v, "@") - $$ = &stmts.VariableAssignment{Name: v, Value: $3.(expression.Expression)} + $$ = &ast.VariableAssignment{Name: v, Value: $3.(ast.ExprNode)} } VariableAssignmentList: { - $$ = []*stmts.VariableAssignment{} + $$ = []*ast.VariableAssignment{} } | VariableAssignment { - $$ = []*stmts.VariableAssignment{$1.(*stmts.VariableAssignment)} + $$ = []*ast.VariableAssignment{$1.(*ast.VariableAssignment)} } | VariableAssignmentList ',' VariableAssignment { - $$ = append($1.([]*stmts.VariableAssignment), $3.(*stmts.VariableAssignment)) + $$ = append($1.([]*ast.VariableAssignment), $3.(*ast.VariableAssignment)) } Variable: @@ -3442,7 +3169,7 @@ SystemVariable: } else if strings.HasPrefix(v, "@@") { v = strings.TrimPrefix(v, "@@") } - $$ = &expression.Variable{Name: v, IsGlobal: isGlobal, IsSystem: true} + $$ = &ast.VariableExpr{Name: v, IsGlobal: isGlobal, IsSystem: true} } UserVariable: @@ -3450,7 +3177,7 @@ UserVariable: { v := $1.(string) v = strings.TrimPrefix(v, "@") - $$ = &expression.Variable{Name: v, IsGlobal: false, IsSystem: false} + $$ = &ast.VariableExpr{Name: v, IsGlobal: false, IsSystem: false} } Username: @@ -3479,77 +3206,91 @@ AuthString: ShowStmt: "SHOW" "ENGINES" { - $$ = &stmts.ShowStmt{Target: stmt.ShowEngines} + $$ = &ast.ShowStmt{Tp: ast.ShowEngines} } | "SHOW" "DATABASES" { - $$ = &stmts.ShowStmt{Target: stmt.ShowDatabases} + $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} } | "SHOW" "SCHEMAS" { - $$ = &stmts.ShowStmt{Target: stmt.ShowDatabases} + $$ = &ast.ShowStmt{Tp: ast.ShowDatabases} } | "SHOW" "CHARACTER" "SET" { - $$ = &stmts.ShowStmt{Target: stmt.ShowCharset} + $$ = &ast.ShowStmt{Tp: ast.ShowCharset} } | "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt { - stmt := &stmts.ShowStmt{ - Target: stmt.ShowTables, - DBName: $4.(string), - Full: $2.(bool), + stmt := &ast.ShowStmt{ + Tp: ast.ShowTables, + DBName: $4.(string), + Full: $2.(bool), + } + if $5 != nil { + if x, ok := $5.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { + stmt.Where = $5.(ast.ExprNode) + } } - stmt.SetCondition($5) $$ = stmt } -| "SHOW" OptFull "COLUMNS" ShowTableIdentOpt ShowDatabaseNameOpt +| "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt { - $$ = &stmts.ShowStmt{ - Target: stmt.ShowColumns, - TableIdent: $4.(table.Ident), - DBName: $5.(string), - Full: $2.(bool), + $$ = &ast.ShowStmt{ + Tp: ast.ShowColumns, + Table: $4.(*ast.TableName), + DBName: $5.(string), + Full: $2.(bool), } } | "SHOW" "WARNINGS" { - $$ = &stmts.ShowStmt{Target: stmt.ShowWarnings} + $$ = &ast.ShowStmt{Tp: ast.ShowWarnings} } | "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt { - stmt := &stmts.ShowStmt{ - Target: stmt.ShowVariables, + stmt := &ast.ShowStmt{ + Tp: ast.ShowVariables, GlobalScope: $2.(bool), } - stmt.SetCondition($4) + if x, ok := $4.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else if $4 != nil { + stmt.Where = $4.(ast.ExprNode) + } $$ = stmt } | "SHOW" "COLLATION" ShowLikeOrWhereOpt { - stmt := &stmts.ShowStmt{ - Target: stmt.ShowCollation, + stmt := &ast.ShowStmt{ + Tp: ast.ShowCollation, + } + if x, ok := $3.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else if $3 != nil { + stmt.Where = $3.(ast.ExprNode) } - stmt.SetCondition($3) $$ = stmt } -| "SHOW" "CREATE" "TABLE" TableIdent +| "SHOW" "CREATE" "TABLE" TableName { - $$ = &stmts.ShowStmt{ - Target: stmt.ShowCreateTable, - TableIdent: $4.(table.Ident), + $$ = &ast.ShowStmt{ + Tp: ast.ShowCreateTable, + Table: $4.(*ast.TableName), } } | "SHOW" "GRANTS" { // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html - $$ = &stmts.ShowStmt{Target: stmt.ShowGrants} + $$ = &ast.ShowStmt{Tp: ast.ShowGrants} } | "SHOW" "GRANTS" "FOR" Username { // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html - $$ = &stmts.ShowStmt{ - Target: stmt.ShowGrants, + $$ = &ast.ShowStmt{ + Tp: ast.ShowGrants, User: $4.(string), } } @@ -3560,11 +3301,11 @@ ShowLikeOrWhereOpt: } | "LIKE" PrimaryExpression { - $$ = &expression.PatternLike{Pattern: $2.(expression.Expression)} + $$ = &ast.PatternLikeExpr{Pattern: $2.(ast.ExprNode)} } | "WHERE" Expression { - $$ = expression.Expr($2) + $$ = $2.(ast.ExprNode) } GlobalScope: @@ -3602,14 +3343,14 @@ ShowDatabaseNameOpt: $$ = $2.(string) } -ShowTableIdentOpt: - "FROM" TableIdent +ShowTableAliasOpt: + "FROM" TableName { - $$ = $2.(table.Ident) + $$ = $2.(*ast.TableName) } -| "IN" TableIdent +| "IN" TableName { - $$ = $2.(table.Ident) + $$ = $2.(*ast.TableName) } Statement: @@ -3635,17 +3376,17 @@ Statement: | RollbackStmt | ReplaceIntoStmt | SelectStmt +| UnionStmt | SetStmt | ShowStmt | TruncateTableStmt -| UnionStmt | UpdateStmt | UseStmt | SubSelect { // `(select 1)`; is a valid select statement // TODO: This is used to fix issue #320. There may be a better solution. - $$ = $1.(*subquery.SubQuery).Stmt + $$ = $1.(*ast.SubqueryExpr).Query } ExplainableStmt: @@ -3659,46 +3400,38 @@ StatementList: Statement { if $1 != nil { - n, ok := $1.(ast.Node) - if ok { - n.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = []interface{}{n} - } else { - s := $1.(stmt.Statement) - s.SetText(yylex.(*lexer).stmtText()) - yylex.(*lexer).list = []interface{}{s} - } + s := $1.(ast.StmtNode) + s.SetText(yylex.(*lexer).stmtText()) + yylex.(*lexer).list = append(yylex.(*lexer).list, s) } } | StatementList ';' Statement { if $3 != nil { - s := $3.(stmt.Statement) + s := $3.(ast.StmtNode) s.SetText(yylex.(*lexer).stmtText()) yylex.(*lexer).list = append(yylex.(*lexer).list, s) } } -TableConstraint: - "CONSTRAINT" name ConstraintElem +Constraint: + ConstraintKeywordOpt ConstraintElem { - cst := $3.(*coldef.TableConstraint) - cst.ConstrName = $2.(string) + cst := $2.(*ast.Constraint) + if $1 != nil { + cst.Name = $1.(string) + } $$ = cst } -| ConstraintElem - { - $$ = $1 - } TableElement: ColumnDef { - $$ = $1.(*coldef.ColumnDef) + $$ = $1.(*ast.ColumnDef) } -| TableConstraint +| Constraint { - $$ = $1.(*coldef.TableConstraint) + $$ = $1.(*ast.Constraint) } | "CHECK" '(' Expression ')' { @@ -3724,96 +3457,90 @@ TableElementList: } } -TableElementListOpt: - { - $$ = []interface{}{} - } -| TableElementList - -TableOpt: +TableOption: "ENGINE" Identifier { - $$ = &coldef.TableOpt{Tp: coldef.TblOptEngine, StrValue: $2.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $2.(string)} } | "ENGINE" eq Identifier { - $$ = &coldef.TableOpt{Tp: coldef.TblOptEngine, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionEngine, StrValue: $3.(string)} } -| DefaultKwdOpt CharsetKw EqOpt CharsetName +| DefaultKwdOpt CharsetKw EqOpt StringName { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCharset, StrValue: $4.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionCharset, StrValue: $4.(string)} } -| DefaultKwdOpt "COLLATE" EqOpt CollationName +| DefaultKwdOpt "COLLATE" EqOpt StringName { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCollate, StrValue: $4.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionCollate, StrValue: $4.(string)} } | "AUTO_INCREMENT" eq LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptAutoIncrement, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionAutoIncrement, UintValue: $3.(uint64)} } | "COMMENT" EqOpt stringLit { - $$ = &coldef.TableOpt{Tp: coldef.TblOptComment, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionComment, StrValue: $3.(string)} } | "AVG_ROW_LENGTH" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptAvgRowLength, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionAvgRowLength, UintValue: $3.(uint64)} } | "CONNECTION" EqOpt stringLit { - $$ = &coldef.TableOpt{Tp: coldef.TblOptConnection, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionConnection, StrValue: $3.(string)} } | "CHECKSUM" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCheckSum, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionCheckSum, UintValue: $3.(uint64)} } | "PASSWORD" EqOpt stringLit { - $$ = &coldef.TableOpt{Tp: coldef.TblOptPassword, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionPassword, StrValue: $3.(string)} } | "COMPRESSION" EqOpt Identifier { - $$ = &coldef.TableOpt{Tp: coldef.TblOptCompression, StrValue: $3.(string)} + $$ = &ast.TableOption{Tp: ast.TableOptionCompression, StrValue: $3.(string)} } | "KEY_BLOCK_SIZE" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptKeyBlockSize, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionKeyBlockSize, UintValue: $3.(uint64)} } | "MAX_ROWS" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptMaxRows, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionMaxRows, UintValue: $3.(uint64)} } | "MIN_ROWS" EqOpt LengthNum { - $$ = &coldef.TableOpt{Tp: coldef.TblOptMinRows, UintValue: $3.(uint64)} + $$ = &ast.TableOption{Tp: ast.TableOptionMinRows, UintValue: $3.(uint64)} } -TableOptListOpt: +TableOptionListOpt: { - $$ = []*coldef.TableOpt{} + $$ = []*ast.TableOption{} } -| TableOptList %prec lowerThanComma +| TableOptionList %prec lowerThanComma -TableOptList: - TableOpt +TableOptionList: + TableOption { - $$ = []*coldef.TableOpt{$1.(*coldef.TableOpt)} + $$ = []*ast.TableOption{$1.(*ast.TableOption)} } -| TableOptList TableOpt +| TableOptionList TableOption { - $$ = append($1.([]*coldef.TableOpt), $2.(*coldef.TableOpt)) + $$ = append($1.([]*ast.TableOption), $2.(*ast.TableOption)) } -| TableOptList ',' TableOpt +| TableOptionList ',' TableOption { - $$ = append($1.([]*coldef.TableOpt), $3.(*coldef.TableOpt)) + $$ = append($1.([]*ast.TableOption), $3.(*ast.TableOption)) } TruncateTableStmt: - "TRUNCATE" "TABLE" TableIdent + "TRUNCATE" "TABLE" TableName { - $$ = &stmts.TruncateTableStmt{TableIdent: $3.(table.Ident)} + $$ = &ast.TruncateTableStmt{Table: $3.(*ast.TableName)} } /*************************************Type Begin***************************************/ @@ -3879,7 +3606,7 @@ NumericType: } | FixedPointType FloatOpt FieldOpts { - fopt := $2.(*coldef.FloatOpt) + fopt := $2.(*ast.FloatOpt) x := types.NewFieldType($1.(byte)) x.Flen = fopt.Flen x.Decimal = fopt.Decimal @@ -3895,7 +3622,7 @@ NumericType: } | FloatingPointType FloatOpt FieldOpts { - fopt := $2.(*coldef.FloatOpt) + fopt := $2.(*ast.FloatOpt) x := types.NewFieldType($1.(byte)) x.Flen = fopt.Flen if x.Tp == mysql.TypeFloat { @@ -3927,7 +3654,6 @@ NumericType: x.Flen = 1 } else if x.Flen > 64 { yylex.(*lexer).errf("invalid field length %d for bit type, must in [1, 64]", x.Flen) - return 1 } $$ = x } @@ -4011,8 +3737,6 @@ StringType: if $4.(bool) { x.Flag |= mysql.BinaryFlag } - x.Charset = $5.(string) - x.Collate = $6.(string) $$ = x } | NationalOpt "CHAR" OptBinary OptCharset OptCollate @@ -4021,8 +3745,6 @@ StringType: if $3.(bool) { x.Flag |= mysql.BinaryFlag } - x.Charset = $4.(string) - x.Collate = $5.(string) $$ = x } | NationalOpt "VARCHAR" FieldLen OptBinary OptCharset OptCollate @@ -4216,21 +3938,21 @@ FieldOpts: FloatOpt: { - $$ = &coldef.FloatOpt{Flen: types.UnspecifiedLength, Decimal: types.UnspecifiedLength} + $$ = &ast.FloatOpt{Flen: types.UnspecifiedLength, Decimal: types.UnspecifiedLength} } | FieldLen { - $$ = &coldef.FloatOpt{Flen: $1.(int), Decimal: types.UnspecifiedLength} + $$ = &ast.FloatOpt{Flen: $1.(int), Decimal: types.UnspecifiedLength} } | Precision { - $$ = $1.(*coldef.FloatOpt) + $$ = $1.(*ast.FloatOpt) } Precision: '(' LengthNum ',' LengthNum ')' { - $$ = &coldef.FloatOpt{Flen: int($2.(uint64)), Decimal: int($4.(uint64))} + $$ = &ast.FloatOpt{Flen: int($2.(uint64)), Decimal: int($4.(uint64))} } OptBinary: @@ -4246,7 +3968,7 @@ OptCharset: { $$ = "" } -| CharsetKw CharsetName +| CharsetKw StringName { $$ = $2.(string) } @@ -4259,7 +3981,7 @@ OptCollate: { $$ = "" } -| "COLLATE" CollationName +| "COLLATE" StringName { $$ = $2.(string) } @@ -4274,71 +3996,15 @@ StringList: $$ = append($1.([]string), $3.(string)) } -/************************************************************************************ - * Union statement - * See: https://dev.mysql.com/doc/refman/5.7/en/union.html - ***********************************************************************************/ -UnionStmt: - UnionSelect "UNION" UnionOpt SelectStmt +StringName: + stringLit { - ds := []bool {$3.(bool)} - ss := []*stmts.SelectStmt{$1.(*stmts.SelectStmt), $4.(*stmts.SelectStmt)} - $$ = &stmts.UnionStmt{ - Distincts: ds, - Selects: ss, - } + $$ = $1.(string) } -| UnionSelect "UNION" UnionOpt '(' SelectStmt ')' SelectStmtOrder SelectStmtLimit +| Identifier { - ds := []bool {$3.(bool)} - ss := []*stmts.SelectStmt{$1.(*stmts.SelectStmt), $5.(*stmts.SelectStmt)} - st := &stmts.UnionStmt{ - Distincts: ds, - Selects: ss, - } - if $7 != nil { - st.OrderBy = $7.(*rsets.OrderByRset) - } - - if $8 != nil { - ay := $8.([]interface{}) - st.Limit = ay[0].(*rsets.LimitRset) - st.Offset = ay[1].(*rsets.OffsetRset) - } - $$ = st + $$ = $1.(string) } -| UnionSelect "UNION" UnionOpt UnionStmt - { - s := $4.(*stmts.UnionStmt) - s.Distincts = append([]bool {$3.(bool)}, s.Distincts...) - s.Selects = append([]*stmts.SelectStmt{$1.(*stmts.SelectStmt)}, s.Selects...) - $$ = s - } - -UnionSelect: - SelectStmt - { - $$ = $1 - } -| '(' SelectStmt ')' - { - $$ = $2 - } - - -UnionOpt: - { - $$ = true - } -| "ALL" - { - $$ = false - } -| "DISTINCT" - { - $$ = true - } - /*********************************************************************************** * Update Statement @@ -4347,21 +4013,25 @@ UnionOpt: UpdateStmt: "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { - // Single-table syntax - r := &rsets.JoinRset{Left: $4, Right: nil} - st := &stmts.UpdateStmt{ + var refs *ast.Join + if x, ok := $4.(*ast.Join); ok { + refs = x + } else { + refs = &ast.Join{Left: $4.(ast.ResultSetNode)} + } + st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: r, - List: $6.([]*expression.Assignment), + TableRefs: &ast.TableRefsClause{TableRefs: refs}, + List: $6.([]*ast.Assignment), } if $7 != nil { - st.Where = $7.(expression.Expression) + st.Where = $7.(ast.ExprNode) } if $8 != nil { - st.Order = $8.(*rsets.OrderByRset) + st.Order = $8.(*ast.OrderByClause) } if $9 != nil { - st.Limit = $9.(*rsets.LimitRset) + st.Limit = $9.(*ast.Limit) } $$ = st if yylex.(*lexer).root { @@ -4370,15 +4040,13 @@ UpdateStmt: } | "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional { - // Multiple-table syntax - st := &stmts.UpdateStmt{ + st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: $4.(*rsets.JoinRset), - List: $6.([]*expression.Assignment), - MultipleTable: true, + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, + List: $6.([]*ast.Assignment), } if $7 != nil { - st.Where = $7.(expression.Expression) + st.Where = $7.(ast.ExprNode) } $$ = st if yylex.(*lexer).root { @@ -4389,7 +4057,7 @@ UpdateStmt: UseStmt: "USE" DBName { - $$ = &stmts.UseStmt{DBName: $2.(string)} + $$ = &ast.UseStmt{DBName: $2.(string)} if yylex.(*lexer).root { break } @@ -4398,7 +4066,7 @@ UseStmt: WhereClause: "WHERE" Expression { - $$ = expression.Expr($2) + $$ = $2.(ast.ExprNode) } WhereClauseOptional: @@ -4422,36 +4090,35 @@ CommaOpt: * https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html ************************************************************************************/ CreateUserStmt: - "CREATE" "USER" IfNotExists UserSpecificationList + "CREATE" "USER" IfNotExists UserSpecList { // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html - $$ = &stmts.CreateUserStmt{ + $$ = &ast.CreateUserStmt{ IfNotExists: $3.(bool), - Specs: $4.([]*coldef.UserSpecification), + Specs: $4.([]*ast.UserSpec), } } -UserSpecification: +UserSpec: Username AuthOption { - x := &coldef.UserSpecification{ + userSpec := &ast.UserSpec{ User: $1.(string), } if $2 != nil { - x.AuthOpt = $2.(*coldef.AuthOption) + userSpec.AuthOpt = $2.(*ast.AuthOption) } - $$ = x - + $$ = userSpec } -UserSpecificationList: - UserSpecification +UserSpecList: + UserSpec { - $$ = []*coldef.UserSpecification{$1.(*coldef.UserSpecification)} + $$ = []*ast.UserSpec{$1.(*ast.UserSpec)} } -| UserSpecificationList ',' UserSpecification +| UserSpecList ',' UserSpec { - $$ = append($1.([]*coldef.UserSpecification), $3.(*coldef.UserSpecification)) + $$ = append($1.([]*ast.UserSpec), $3.(*ast.UserSpec)) } AuthOption: @@ -4460,14 +4127,14 @@ AuthOption: } | "IDENTIFIED" "BY" AuthString { - $$ = &coldef.AuthOption { + $$ = &ast.AuthOption { AuthString: $3.(string), ByAuthString: true, - } + } } | "IDENTIFIED" "BY" "PASSWORD" HashString { - $$ = &coldef.AuthOption { + $$ = &ast.AuthOption{ HashString: $4.(string), } } @@ -4480,42 +4147,39 @@ HashString: * See: https://dev.mysql.com/doc/refman/5.7/en/grant.html *************************************************************************************/ GrantStmt: - "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecificationList + "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecList { - $$ = &stmts.GrantStmt{ - Privs: $2.([]*coldef.PrivElem), - ObjectType: $4.(int), - Level: $5.(*coldef.GrantLevel), - Users: $7.([]*coldef.UserSpecification), - } + $$ = &ast.GrantStmt{ + Privs: $2.([]*ast.PrivElem), + ObjectType: $4.(ast.ObjectTypeType), + Level: $5.(*ast.GrantLevel), + Users: $7.([]*ast.UserSpec), + } } PrivElem: - PrivType ColumnListOpt + PrivType { - $$ = &coldef.PrivElem{ + $$ = &ast.PrivElem{ Priv: $1.(mysql.PrivilegeType), - Cols: $2.([]string), } } - -ColumnListOpt: +| PrivType '(' ColumnNameList ')' { - $$ = []string{} - } -| '(' ColumnNameList ')' - { - $$ = $2 + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + Cols: $3.([]*ast.ColumnName), + } } PrivElemList: PrivElem { - $$ = []*coldef.PrivElem{$1.(*coldef.PrivElem)} + $$ = []*ast.PrivElem{$1.(*ast.PrivElem)} } | PrivElemList ',' PrivElem { - $$ = append($1.([]*coldef.PrivElem), $3.(*coldef.PrivElem)) + $$ = append($1.([]*ast.PrivElem), $3.(*ast.PrivElem)) } PrivType: @@ -4571,50 +4235,49 @@ PrivType: { $$ = mysql.GrantPriv } - + ObjectType: { - $$ = coldef.ObjectTypeNone + $$ = ast.ObjectTypeNone } | "TABLE" { - $$ = coldef.ObjectTypeTable + $$ = ast.ObjectTypeTable } PrivLevel: '*' { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelDB, - } + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + } } | '*' '.' '*' { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelGlobal, - } + $$ = &ast.GrantLevel { + Level: ast.GrantLevelGlobal, + } } | Identifier '.' '*' { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelDB, + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, DBName: $1.(string), - } + } } | Identifier '.' Identifier { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelTable, + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, DBName: $1.(string), TableName: $3.(string), - } + } } | Identifier { - $$ = &coldef.GrantLevel { - Level: coldef.GrantLevelTable, + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, TableName: $1.(string), - } + } } %% - diff --git a/parser/parser_test.go b/parser/parser_test.go index 8f5c0d10ee..50d6e5cf2d 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -18,9 +18,7 @@ import ( "testing" . "github.com/pingcap/check" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/subquery" - "github.com/pingcap/tidb/stmt/stmts" + "github.com/pingcap/tidb/ast" ) func TestT(t *testing.T) { @@ -32,23 +30,6 @@ var _ = Suite(&testParserSuite{}) type testParserSuite struct { } -func (s *testParserSuite) TestOriginText(c *C) { - src := `SELECT stuff.id - FROM stuff - WHERE stuff.value >= ALL (SELECT stuff.value - FROM stuff)` - - l := NewLexer(src) - c.Assert(yyParse(l), Equals, 0) - node := l.Stmts()[0].(*stmts.SelectStmt) - sq := node.Where.Expr.(*expression.CompareSubQuery).R - c.Assert(sq, NotNil) - subsel := sq.(*subquery.SubQuery) - c.Assert(subsel.Stmt.OriginText(), Equals, - `SELECT stuff.value - FROM stuff`) -} - func (s *testParserSuite) TestSimple(c *C) { // Testcase for unreserved keywords unreservedKws := []string{ @@ -70,15 +51,12 @@ func (s *testParserSuite) TestSimple(c *C) { // Testcase for prepared statement src := "SELECT id+?, id+? from t;" l := NewLexer(src) - l.SetPrepare() c.Assert(yyParse(l), Equals, 0) - c.Assert(len(l.ParamList), Equals, 2) c.Assert(len(l.Stmts()), Equals, 1) // Testcase for -- Comment and unary -- operator src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;" l = NewLexer(src) - l.SetPrepare() c.Assert(yyParse(l), Equals, 0) c.Assert(len(l.Stmts()), Equals, 2) @@ -87,11 +65,12 @@ func (s *testParserSuite) TestSimple(c *C) { l = NewLexer(src) c.Assert(yyParse(l), Equals, 0) st := l.Stmts()[0] - ss, ok := st.(*stmts.SelectStmt) + ss, ok := st.(*ast.SelectStmt) c.Assert(ok, IsTrue) - cv, ok := ss.Fields[0].Expr.(*expression.FunctionCast) + c.Assert(len(ss.Fields.Fields), Equals, 1) + cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr) c.Assert(ok, IsTrue) - c.Assert(cv.FunctionType, Equals, expression.ConvertFunction) + c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction) // For query start with comment srcs := []string{ @@ -105,7 +84,7 @@ func (s *testParserSuite) TestSimple(c *C) { l = NewLexer(src) c.Assert(yyParse(l), Equals, 0) st = l.Stmts()[0] - ss, ok = st.(*stmts.SelectStmt) + ss, ok = st.(*ast.SelectStmt) c.Assert(ok, IsTrue) } } @@ -172,9 +151,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) { {"REPLACE INTO foo () VALUES ()", true}, {"REPLACE INTO foo VALUE ()", true}, // 40 - {`SELECT stuff.id - FROM stuff - WHERE stuff.value >= ALL (SELECT stuff.value + {`SELECT stuff.id + FROM stuff + WHERE stuff.value >= ALL (SELECT stuff.value FROM stuff)`, true}, {"BEGIN", true}, {"START TRANSACTION", true}, diff --git a/parser/scanner.l b/parser/scanner.l index c1441b4650..191eaef35e 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -22,25 +22,25 @@ import ( "fmt" "math" "strconv" - "unicode" "strings" - - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/util/stringutil" - "github.com/pingcap/tidb/util/charset" + "unicode" + + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/charset" + "github.com/pingcap/tidb/util/stringutil" ) type lexer struct { c int col int errs []error - expr expression.Expression + expr ast.ExprNode i int inj int lcol int line int - list []interface{} + list []ast.StmtNode ncol int nline int sc int @@ -48,8 +48,7 @@ type lexer struct { val []byte ungetBuf []byte root bool - prepare bool - ParamList []*expression.ParamMarker + prepare bool stmtStartPos int stringLit []byte @@ -78,11 +77,11 @@ func (l *lexer) Errors() []error { return l.errs } -func (l *lexer) Stmts() []interface{}{ +func (l *lexer) Stmts() []ast.StmtNode { return l.list } -func (l *lexer) Expr() expression.Expression { +func (l *lexer) Expr() ast.ExprNode { return l.expr } @@ -90,16 +89,16 @@ func (l *lexer) Inj() int { return l.inj } +func (l *lexer) SetInj(inj int) { + l.inj = inj +} + func (l *lexer) SetPrepare() { - l.prepare = true + l.prepare = true } func (l *lexer) IsPrepare() bool { - return l.prepare -} - -func (l *lexer) SetInj(inj int) { - l.inj = inj + return l.prepare } func (l *lexer) Root() bool { @@ -116,7 +115,17 @@ func (l *lexer) SetCharsetInfo(charset, collation string) { } func (l *lexer) GetCharsetInfo() (string, string) { - return l.charset, l.collation + return l.charset, l.collation +} + +// The select statement is not at the end of the whole statement, if the last +// field text was set from its offset to the end of the src string, update +// the last field text. +func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 { + lastField.SetText(l.src[lastField.Offset:lastEnd]) + } } func (l *lexer) unget(b byte) { @@ -822,9 +831,9 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {repeat} lval.item = string(l.val) return repeat {regexp} return regexp -{references} return references {replace} lval.item = string(l.val) return replace +{references} return references {rlike} return rlike {sys_var} lval.item = string(l.val) diff --git a/parser/scanner_test.go b/parser/scanner_test.go deleted file mode 100644 index 0d920cbd5a..0000000000 --- a/parser/scanner_test.go +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package parser - -import ( - "fmt" - "unicode" - - . "github.com/pingcap/check" -) - -func tok2name(i int) string { - if i == unicode.ReplacementChar { - return "" - } - - if i < 128 { - return fmt.Sprintf("tok-'%c'", i) - } - - return fmt.Sprintf("tok-%d", i) -} - -func (s *testParserSuite) TestScaner0(c *C) { - table := []struct { - src string - tok, line, col, nline, ncol int - val string - }{ - {"a", identifier, 1, 1, 1, 2, "a"}, - {" a", identifier, 1, 2, 1, 3, "a"}, - {"a ", identifier, 1, 1, 1, 2, "a"}, - {" a ", identifier, 1, 2, 1, 3, "a"}, - {"\na", identifier, 2, 1, 2, 2, "a"}, - - {"a\n", identifier, 1, 1, 1, 2, "a"}, - {"\na\n", identifier, 2, 1, 2, 2, "a"}, - {"\n a", identifier, 2, 2, 2, 3, "a"}, - {"a \n", identifier, 1, 1, 1, 2, "a"}, - {"\n a \n", identifier, 2, 2, 2, 3, "a"}, - - {"ab", identifier, 1, 1, 1, 3, "ab"}, - {" ab", identifier, 1, 2, 1, 4, "ab"}, - {"ab ", identifier, 1, 1, 1, 3, "ab"}, - {" ab ", identifier, 1, 2, 1, 4, "ab"}, - {"\nab", identifier, 2, 1, 2, 3, "ab"}, - - {"ab\n", identifier, 1, 1, 1, 3, "ab"}, - {"\nab\n", identifier, 2, 1, 2, 3, "ab"}, - {"\n ab", identifier, 2, 2, 2, 4, "ab"}, - {"ab \n", identifier, 1, 1, 1, 3, "ab"}, - {"\n ab \n", identifier, 2, 2, 2, 4, "ab"}, - - {"c", identifier, 1, 1, 1, 2, "c"}, - {"cR", identifier, 1, 1, 1, 3, "cR"}, - {"cRe", identifier, 1, 1, 1, 4, "cRe"}, - {"cReA", identifier, 1, 1, 1, 5, "cReA"}, - {"cReAt", identifier, 1, 1, 1, 6, "cReAt"}, - - {"cReATe", create, 1, 1, 1, 7, "cReATe"}, - {"cReATeD", identifier, 1, 1, 1, 8, "cReATeD"}, - {"2", intLit, 1, 1, 1, 2, "2"}, - {"2.", floatLit, 1, 1, 1, 3, "2."}, - {"2.3", floatLit, 1, 1, 1, 4, "2.3"}, - } - - lval := &yySymType{} - for _, t := range table { - l := NewLexer(t.src) - tok := l.Lex(lval) - nline, ncol := l.npos() - val := string(l.val) - - c.Assert(tok, Equals, t.tok) - c.Assert(l.line, Equals, t.line) - c.Assert(l.col, Equals, t.col) - c.Assert(nline, Equals, t.nline) - c.Assert(ncol, Equals, t.ncol) - c.Assert(val, Equals, t.val) - } -} diff --git a/stmt/stmts/drop_test.go b/stmt/stmts/drop_test.go index 5d0a34ec0d..4a1de810c2 100644 --- a/stmt/stmts/drop_test.go +++ b/stmt/stmts/drop_test.go @@ -62,7 +62,7 @@ func (s *testStmtSuite) TestDropTable(c *C) { } func (s *testStmtSuite) TestDropIndex(c *C) { - testSQL := "drop index if exists drop_index;" + testSQL := "drop index if exists drop_index on t;" stmtList, err := tidb.Compile(s.ctx, testSQL) c.Assert(err, IsNil) diff --git a/stmt/stmts/show_test.go b/stmt/stmts/show_test.go index 5dc0c57828..e67813a46d 100644 --- a/stmt/stmts/show_test.go +++ b/stmt/stmts/show_test.go @@ -51,5 +51,5 @@ func (s *testStmtSuite) TestShow(c *C) { c.Assert(stmtList, HasLen, 1) testStmt, ok = stmtList[0].(*stmts.ShowStmt) c.Assert(ok, IsTrue) - c.Assert(testStmt.Pattern, NotNil) + c.Assert(testStmt.Pattern, NotNil, Commentf("S: %s", testStmt.Text)) } diff --git a/tidb.go b/tidb.go index 55581af782..30a95b608f 100644 --- a/tidb.go +++ b/tidb.go @@ -26,7 +26,6 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/expression" @@ -132,15 +131,12 @@ func Compile(ctx context.Context, src string) ([]stmt.Statement, error) { rawStmt := l.Stmts() stmts := make([]stmt.Statement, len(rawStmt)) for i, v := range rawStmt { - if node, ok := v.(ast.Node); ok { - stm, err := optimizer.Compile(node) - if err != nil { - return nil, errors.Trace(err) - } - stmts[i] = stm - } else { - stmts[i] = v.(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, err := compiler.Compile(v) + if err != nil { + return nil, errors.Trace(err) } + stmts[i] = stm } return stmts, nil } @@ -162,7 +158,12 @@ func CompilePrepare(ctx context.Context, src string) (stmt.Statement, []*express return nil, nil, nil } sm := sms[0] - return sm.(stmt.Statement), l.ParamList, nil + compiler := &optimizer.Compiler{} + statement, err := compiler.Compile(sm) + if err != nil { + return nil, nil, errors.Trace(err) + } + return statement, compiler.ParamMarkers(), nil } func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) { diff --git a/tidb_test.go b/tidb_test.go index 34bb2620ca..f635867840 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -640,6 +640,7 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) { // conflict mustExecSQL(c, se1, "begin") rs, err := exec(c, se1, "select * from t where c1=11 for update") + c.Assert(err, IsNil) _, err = rs.Rows(-1, 0) mustExecSQL(c, se2, "begin")