diff --git a/ast/ast.go b/ast/ast.go index 0e9ab19893..c38c8bc32d 100644 --- a/ast/ast.go +++ b/ast/ast.go @@ -26,6 +26,11 @@ type Node interface { // Accept accepts Visitor to visit itself. // The returned node should replace original node. // ok returns false to stop visiting. + // + // Implementation of this method should first call visitor.Enter, + // assign the returned node to its method receiver, if skipChildren returns true, + // children should be skipped. Otherwise, call its children in particular order that + // later elements depends on former elements. Finally, return visitor.Leave. Accept(v Visitor) (node Node, ok bool) // Text returns the original text of the element. Text() string @@ -44,6 +49,10 @@ type ExprNode interface { SetType(tp *types.FieldType) // GetType gets the evaluation type of the expression. GetType() *types.FieldType + // SetValue sets value to the expression. + SetValue(val interface{}) + // GetValue gets value of the expression. + GetValue() interface{} } // FuncNode represents function call expression node. @@ -93,11 +102,12 @@ type ResultSetNode interface { // Visitor visits a Node. type Visitor interface { // VisitEnter is called before children nodes is visited. + // The returned node must be the same type as the input node n. // skipChildren returns true means children nodes should be skipped, // this is useful when work is done in Enter and there is no need to visit children. - // ok returns false to stop visiting. - Enter(n Node) (skipChildren bool, ok bool) + Enter(n Node) (node Node, skipChildren bool) // VisitLeave is called after children nodes has been visited. + // The returned node must be the same type as the input node n. // ok returns false to stop visiting. Leave(n Node) (node Node, ok bool) } diff --git a/ast/base.go b/ast/base.go index 50e9ef030c..877fa80b51 100644 --- a/ast/base.go +++ b/ast/base.go @@ -62,7 +62,7 @@ func (dn *dmlNode) dmlStatement() {} // Expression implementations should embed it in. type exprNode struct { node - tp *types.FieldType + types.DataItem } // IsStatic implements Expression interface. @@ -72,12 +72,22 @@ func (en *exprNode) IsStatic() bool { // SetType implements Expression interface. func (en *exprNode) SetType(tp *types.FieldType) { - en.tp = tp + en.Type = tp } // GetType implements Expression interface. func (en *exprNode) GetType() *types.FieldType { - return en.tp + return en.Type +} + +// SetValue implements Expression interface. +func (en *exprNode) SetValue(val interface{}) { + en.Data = val +} + +// GetValue implements Expression interface. +func (en *exprNode) GetValue() interface{} { + return en.Data } type funcNode struct { diff --git a/ast/cloner.go b/ast/cloner.go new file mode 100644 index 0000000000..d3d8d282e5 --- /dev/null +++ b/ast/cloner.go @@ -0,0 +1,173 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package ast + +import "fmt" + +// Cloner is an ast visitor that clones a node. +type Cloner struct { +} + +// Enter implements Visitor Enter interface. +func (c *Cloner) Enter(node Node) (Node, bool) { + return copyStruct(node), false +} + +// Leave implements Visitor Leave interface. +func (c *Cloner) Leave(in Node) (out Node, ok bool) { + return in, true +} + +// copyStruct copies a node's struct value, if the struct has slice member, +// make a new slice and copy old slice value to new slice. +func copyStruct(in Node) (out Node) { + switch v := in.(type) { + case *ValueExpr: + nv := *v + out = &nv + case *BetweenExpr: + nv := *v + out = &nv + case *BinaryOperationExpr: + nv := *v + out = &nv + case *WhenClause: + nv := *v + out = &nv + case *CaseExpr: + nv := *v + nv.WhenClauses = make([]*WhenClause, len(v.WhenClauses)) + copy(nv.WhenClauses, v.WhenClauses) + out = &nv + case *SubqueryExpr: + nv := *v + out = &nv + case *CompareSubqueryExpr: + nv := *v + out = &nv + case *ColumnName: + nv := *v + out = &nv + case *ColumnNameExpr: + nv := *v + out = &nv + case *DefaultExpr: + nv := *v + out = &nv + case *IdentifierExpr: + nv := *v + out = &nv + case *ExistsSubqueryExpr: + nv := *v + out = &nv + case *PatternInExpr: + nv := *v + nv.List = make([]ExprNode, len(v.List)) + copy(nv.List, v.List) + out = &nv + case *IsNullExpr: + nv := *v + out = &nv + case *IsTruthExpr: + nv := *v + out = &nv + case *PatternLikeExpr: + nv := *v + out = &nv + case *ParamMarkerExpr: + nv := *v + out = &nv + case *ParenthesesExpr: + nv := *v + out = &nv + case *PositionExpr: + nv := *v + out = &nv + case *PatternRegexpExpr: + nv := *v + out = &nv + case *RowExpr: + nv := *v + nv.Values = make([]ExprNode, len(v.Values)) + copy(nv.Values, v.Values) + out = &nv + case *UnaryOperationExpr: + nv := *v + out = &nv + case *ValuesExpr: + nv := *v + out = &nv + case *VariableExpr: + nv := *v + out = &nv + case *Join: + nv := *v + out = &nv + case *TableName: + nv := *v + out = &nv + case *TableSource: + nv := *v + out = &nv + case *OnCondition: + nv := *v + out = &nv + case *WildCardField: + nv := *v + out = &nv + case *SelectField: + nv := *v + out = &nv + case *FieldList: + nv := *v + nv.Fields = make([]*SelectField, len(v.Fields)) + copy(nv.Fields, v.Fields) + out = &nv + case *TableRefsClause: + nv := *v + out = &nv + case *ByItem: + nv := *v + out = &nv + case *GroupByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *HavingClause: + nv := *v + out = &nv + case *OrderByClause: + nv := *v + nv.Items = make([]*ByItem, len(v.Items)) + copy(nv.Items, v.Items) + out = &nv + case *SelectStmt: + nv := *v + out = &nv + case *UnionClause: + nv := *v + out = &nv + case *UnionStmt: + nv := *v + nv.Selects = make([]*SelectStmt, len(v.Selects)) + copy(nv.Selects, v.Selects) + out = &nv + default: + // We currently only handle expression and select statement. + // Will add more when we need to. + panic("unknown ast Node type " + fmt.Sprintf("%T", v)) + } + return +} diff --git a/ast/cloner_test.go b/ast/cloner_test.go new file mode 100644 index 0000000000..f6f567205b --- /dev/null +++ b/ast/cloner_test.go @@ -0,0 +1,40 @@ +package ast + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb/parser/opcode" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testClonerSuite{}) + +type testClonerSuite struct { +} + +func (ts *testClonerSuite) TestCloner(c *C) { + cloner := &Cloner{} + + a := &UnaryOperationExpr{ + Op: opcode.Not, + V: &UnaryOperationExpr{V: NewValueExpr(true)}, + } + + b, ok := a.Accept(cloner) + c.Assert(ok, IsTrue) + a1 := a.V + b1 := b.(*UnaryOperationExpr).V + c.Assert(a1, Not(Equals), b1) + a2 := a1.(*UnaryOperationExpr).V + b2 := b1.(*UnaryOperationExpr).V + c.Assert(a2, Not(Equals), b2) + a3 := a2.(*ValueExpr) + b3 := b2.(*ValueExpr) + c.Assert(a3, Not(Equals), b3) + c.Assert(a3.GetValue(), Equals, true) + c.Assert(b3.GetValue(), Equals, true) +} diff --git a/ast/ddl.go b/ast/ddl.go index 0e43b69552..4f655d6f01 100644 --- a/ast/ddl.go +++ b/ast/ddl.go @@ -70,11 +70,13 @@ type CreateDatabaseStmt struct { } // Accept implements Node Accept interface. -func (cd *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (n *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cd) + n = newNod.(*CreateDatabaseStmt) + return v.Leave(n) } // DropDatabaseStmt is a statement to drop a database and all tables in the database. @@ -87,11 +89,13 @@ type DropDatabaseStmt struct { } // Accept implements Node Accept interface. -func (dd *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dd); skipChildren { - return dd, ok +func (n *DropDatabaseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(dd) + n = newNod.(*DropDatabaseStmt) + return v.Leave(n) } // IndexColName is used for parsing index column name from SQL. @@ -103,16 +107,18 @@ type IndexColName struct { } // Accept implements Node Accept interface. -func (ic *IndexColName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ic); skipChildren { - return ic, ok +func (n *IndexColName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ic.Column.Accept(v) + n = newNod.(*IndexColName) + node, ok := n.Column.Accept(v) if !ok { - return ic, false + return n, false } - ic.Column = node.(*ColumnName) - return v.Leave(ic) + n.Column = node.(*ColumnName) + return v.Leave(n) } // ReferenceDef is used for parsing foreign key reference option from SQL. @@ -125,23 +131,25 @@ type ReferenceDef struct { } // Accept implements Node Accept interface. -func (rd *ReferenceDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rd); skipChildren { - return rd, ok +func (n *ReferenceDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := rd.Table.Accept(v) + n = newNod.(*ReferenceDef) + node, ok := n.Table.Accept(v) if !ok { - return rd, false + return n, false } - rd.Table = node.(*TableName) - for i, val := range rd.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return rd, false + return n, false } - rd.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(rd) + return v.Leave(n) } // ColumnOptionType is the type for ColumnOption. @@ -175,18 +183,20 @@ type ColumnOption struct { } // Accept implements Node Accept interface. -func (co *ColumnOption) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(co); skipChildren { - return co, ok +func (n *ColumnOption) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if co.Expr != nil { - node, ok := co.Expr.Accept(v) + n = newNod.(*ColumnOption) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return co, false + return n, false } - co.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(co) + return v.Leave(n) } // ConstraintType is the type for Constraint. @@ -220,25 +230,27 @@ type Constraint struct { } // Accept implements Node Accept interface. -func (tc *Constraint) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tc); skipChildren { - return tc, ok +func (n *Constraint) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range tc.Keys { + n = newNod.(*Constraint) + for i, val := range n.Keys { node, ok := val.Accept(v) if !ok { - return tc, false + return n, false } - tc.Keys[i] = node.(*IndexColName) + n.Keys[i] = node.(*IndexColName) } - if tc.Refer != nil { - node, ok := tc.Refer.Accept(v) + if n.Refer != nil { + node, ok := n.Refer.Accept(v) if !ok { - return tc, false + return n, false } - tc.Refer = node.(*ReferenceDef) + n.Refer = node.(*ReferenceDef) } - return v.Leave(tc) + return v.Leave(n) } // ColumnDef is used for parsing column definition from SQL. @@ -251,23 +263,25 @@ type ColumnDef struct { } // Accept implements Node Accept interface. -func (cd *ColumnDef) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cd); skipChildren { - return cd, ok +func (n *ColumnDef) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := cd.Name.Accept(v) + n = newNod.(*ColumnDef) + node, ok := n.Name.Accept(v) if !ok { - return cd, false + return n, false } - cd.Name = node.(*ColumnName) - for i, val := range cd.Options { + n.Name = node.(*ColumnName) + for i, val := range n.Options { node, ok := val.Accept(v) if !ok { - return cd, false + return n, false } - cd.Options[i] = node.(*ColumnOption) + n.Options[i] = node.(*ColumnOption) } - return v.Leave(cd) + return v.Leave(n) } // CreateTableStmt is a statement to create a table. @@ -283,30 +297,32 @@ type CreateTableStmt struct { } // Accept implements Node Accept interface. -func (ct *CreateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ct); skipChildren { - return ct, ok +func (n *CreateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ct.Table.Accept(v) + n = newNod.(*CreateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return ct, false + return n, false } - ct.Table = node.(*TableName) - for i, val := range ct.Cols { + n.Table = node.(*TableName) + for i, val := range n.Cols { node, ok = val.Accept(v) if !ok { - return ct, false + return n, false } - ct.Cols[i] = node.(*ColumnDef) + n.Cols[i] = node.(*ColumnDef) } - for i, val := range ct.Constraints { + for i, val := range n.Constraints { node, ok = val.Accept(v) if !ok { - return ct, false + return n, false } - ct.Constraints[i] = node.(*Constraint) + n.Constraints[i] = node.(*Constraint) } - return v.Leave(ct) + return v.Leave(n) } // DropTableStmt is a statement to drop one or more tables. @@ -319,18 +335,20 @@ type DropTableStmt struct { } // Accept implements Node Accept interface. -func (dt *DropTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(dt); skipChildren { - return dt, ok +func (n *DropTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range dt.Tables { + n = newNod.(*DropTableStmt) + for i, val := range n.Tables { node, ok := val.Accept(v) if !ok { - return dt, false + return n, false } - dt.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - return v.Leave(dt) + return v.Leave(n) } // CreateIndexStmt is a statement to create an index. @@ -345,23 +363,25 @@ type CreateIndexStmt struct { } // Accept implements Node Accept interface. -func (ci *CreateIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ci); skipChildren { - return ci, ok +func (n *CreateIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ci.Table.Accept(v) + n = newNod.(*CreateIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return ci, false + return n, false } - ci.Table = node.(*TableName) - for i, val := range ci.IndexColNames { + n.Table = node.(*TableName) + for i, val := range n.IndexColNames { node, ok = val.Accept(v) if !ok { - return ci, false + return n, false } - ci.IndexColNames[i] = node.(*IndexColName) + n.IndexColNames[i] = node.(*IndexColName) } - return v.Leave(ci) + return v.Leave(n) } // DropIndexStmt is a statement to drop the index. @@ -375,16 +395,18 @@ type DropIndexStmt struct { } // Accept implements Node Accept interface. -func (di *DropIndexStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(di); skipChildren { - return di, ok +func (n *DropIndexStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := di.Table.Accept(v) + n = newNod.(*DropIndexStmt) + node, ok := n.Table.Accept(v) if !ok { - return di, false + return n, false } - di.Table = node.(*TableName) - return v.Leave(di) + n.Table = node.(*TableName) + return v.Leave(n) } // TableOptionType is the type for TableOption @@ -435,18 +457,20 @@ type ColumnPosition struct { } // Accept implements Node Accept interface. -func (cp *ColumnPosition) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cp); skipChildren { - return cp, ok +func (n *ColumnPosition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if cp.RelativeColumn != nil { - node, ok := cp.RelativeColumn.Accept(v) + n = newNod.(*ColumnPosition) + if n.RelativeColumn != nil { + node, ok := n.RelativeColumn.Accept(v) if !ok { - return cp, false + return n, false } - cp.RelativeColumn = node.(*ColumnName) + n.RelativeColumn = node.(*ColumnName) } - return v.Leave(cp) + return v.Leave(n) } // AlterTableType is the type for AlterTableSpec. @@ -474,44 +498,46 @@ type AlterTableSpec struct { Constraint *Constraint Options []*TableOption Column *ColumnDef - ColumnName *ColumnName + DropColumn *ColumnName Position *ColumnPosition } // Accept implements Node Accept interface. -func (as *AlterTableSpec) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (n *AlterTableSpec) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if as.Constraint != nil { - node, ok := as.Constraint.Accept(v) + n = newNod.(*AlterTableSpec) + if n.Constraint != nil { + node, ok := n.Constraint.Accept(v) if !ok { - return as, false + return n, false } - as.Constraint = node.(*Constraint) + n.Constraint = node.(*Constraint) } - if as.Column != nil { - node, ok := as.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return as, false + return n, false } - as.Column = node.(*ColumnDef) + n.Column = node.(*ColumnDef) } - if as.ColumnName != nil { - node, ok := as.ColumnName.Accept(v) + if n.DropColumn != nil { + node, ok := n.DropColumn.Accept(v) if !ok { - return as, false + return n, false } - as.ColumnName = node.(*ColumnName) + n.DropColumn = node.(*ColumnName) } - if as.Position != nil { - node, ok := as.Position.Accept(v) + if n.Position != nil { + node, ok := n.Position.Accept(v) if !ok { - return as, false + return n, false } - as.Position = node.(*ColumnPosition) + n.Position = node.(*ColumnPosition) } - return v.Leave(as) + return v.Leave(n) } // AlterTableStmt is a statement to change the structure of a table. @@ -524,23 +550,25 @@ type AlterTableStmt struct { } // Accept implements Node Accept interface. -func (at *AlterTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(at); skipChildren { - return at, ok +func (n *AlterTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := at.Table.Accept(v) + n = newNod.(*AlterTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return at, false + return n, false } - at.Table = node.(*TableName) - for i, val := range at.Specs { + n.Table = node.(*TableName) + for i, val := range n.Specs { node, ok = val.Accept(v) if !ok { - return at, false + return n, false } - at.Specs[i] = node.(*AlterTableSpec) + n.Specs[i] = node.(*AlterTableSpec) } - return v.Leave(at) + return v.Leave(n) } // TruncateTableStmt is a statement to empty a table completely. @@ -552,14 +580,16 @@ type TruncateTableStmt struct { } // Accept implements Node Accept interface. -func (ts *TruncateTableStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (n *TruncateTableStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Table.Accept(v) + n = newNod.(*TruncateTableStmt) + node, ok := n.Table.Accept(v) if !ok { - return ts, false + return n, false } - ts.Table = node.(*TableName) - return v.Leave(ts) + n.Table = node.(*TableName) + return v.Leave(n) } diff --git a/ast/dml.go b/ast/dml.go index b8ada81d99..e15b5f83bb 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -22,6 +22,7 @@ var ( _ DMLNode = &DeleteStmt{} _ DMLNode = &UpdateStmt{} _ DMLNode = &SelectStmt{} + _ DMLNode = &UnionStmt{} _ Node = &Join{} _ Node = &TableName{} _ Node = &TableSource{} @@ -55,34 +56,36 @@ type Join struct { // Tp represents join type. Tp JoinType // On represents join on condition. - On ExprNode + On *OnCondition } // Accept implements Node Accept interface. -func (j *Join) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(j); skipChildren { - return j, ok +func (n *Join) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := j.Left.Accept(v) + n = newNod.(*Join) + node, ok := n.Left.Accept(v) if !ok { - return j, false + return n, false } - j.Left = node.(ResultSetNode) - if j.Right != nil { - node, ok = j.Right.Accept(v) + n.Left = node.(ResultSetNode) + if n.Right != nil { + node, ok = n.Right.Accept(v) if !ok { - return j, false + return n, false } - j.Right = node.(ResultSetNode) + n.Right = node.(ResultSetNode) } - if j.On != nil { - node, ok = j.On.Accept(v) + if n.On != nil { + node, ok = n.On.Accept(v) if !ok { - return j, false + return n, false } - j.On = node.(ExprNode) + n.On = node.(*OnCondition) } - return v.Leave(j) + return v.Leave(n) } // TableName represents a table name. @@ -98,11 +101,13 @@ type TableName struct { } // Accept implements Node Accept interface. -func (tr *TableName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tr); skipChildren { - return tr, ok +func (n *TableName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(tr) + n = newNod.(*TableName) + return v.Leave(n) } // TableSource represents table source with a name. @@ -110,7 +115,7 @@ type TableSource struct { node // Source is the source of the data, can be a TableName, - // a SubQuery, or a JoinNode. + // a SelectStmt, a UnionStmt, or a JoinNode. Source ResultSetNode // AsName is the as name of the table source. @@ -118,47 +123,50 @@ type TableSource struct { } // Accept implements Node Accept interface. -func (ts *TableSource) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ts); skipChildren { - return ts, ok +func (n *TableSource) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ts.Source.Accept(v) + n = newNod.(*TableSource) + node, ok := n.Source.Accept(v) if !ok { - return ts, false + return n, false } - ts.Source = node.(ResultSetNode) - return v.Leave(ts) + n.Source = node.(ResultSetNode) + return v.Leave(n) } -// SetResultFields implements ResultSet interface. -func (ts *TableSource) SetResultFields(rfs []*ResultField) { - ts.Source.SetResultFields(rfs) -} - -// GetResultFields implements ResultSet interface. -func (ts *TableSource) GetResultFields() []*ResultField { - return ts.Source.GetResultFields() -} - -// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. -type UnionClause struct { +// OnCondition represetns JOIN on condition. +type OnCondition struct { node - Distinct bool - Select *SelectStmt + Expr ExprNode } // Accept implements Node Accept interface. -func (uc *UnionClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(uc); skipChildren { - return uc, ok +func (n *OnCondition) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := uc.Select.Accept(v) + n = newNod.(*OnCondition) + node, ok := n.Expr.Accept(v) if !ok { - return uc, false + return n, false } - uc.Select = node.(*SelectStmt) - return v.Leave(uc) + n.Expr = node.(ExprNode) + return v.Leave(n) +} + +// SetResultFields implements ResultSet interface. +func (n *TableSource) SetResultFields(rfs []*ResultField) { + n.Source.SetResultFields(rfs) +} + +// GetResultFields implements ResultSet interface. +func (n *TableSource) GetResultFields() []*ResultField { + return n.Source.GetResultFields() } // SelectLockType is the lock type for SelectStmt. @@ -175,22 +183,18 @@ const ( type WildCardField struct { node - Table *TableName + Table model.CIStr + Schema model.CIStr } // Accept implements Node Accept interface. -func (wf *WildCardField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(wf); skipChildren { - return wf, ok +func (n *WildCardField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if wf.Table != nil { - node, ok := wf.Table.Accept(v) - if !ok { - return wf, false - } - wf.Table = node.(*TableName) - } - return v.Leave(wf) + n = newNod.(*WildCardField) + return v.Leave(n) } // SelectField represents fields in select statement. @@ -199,6 +203,8 @@ func (wf *WildCardField) Accept(v Visitor) (Node, bool) { type SelectField struct { node + // Offset is used to get original text. + Offset int // If WildCard is not nil, Expr will be nil. WildCard *WildCardField // If Expr is not nil, WildCard will be nil. @@ -208,22 +214,70 @@ type SelectField struct { } // Accept implements Node Accept interface. -func (sf *SelectField) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (n *SelectField) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if sf.Expr != nil { - node, ok := sf.Expr.Accept(v) + n = newNod.(*SelectField) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return sf, false + return n, false } - sf.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(sf) + return v.Leave(n) } -// OrderByItem represents a single order by item. -type OrderByItem struct { +// FieldList represents field list in select statement. +type FieldList struct { + node + + Fields []*SelectField +} + +// Accept implements Node Accept interface. +func (n *FieldList) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*FieldList) + for i, val := range n.Fields { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Fields[i] = node.(*SelectField) + } + return v.Leave(n) +} + +// TableRefsClause represents table references clause in dml statement. +type TableRefsClause struct { + node + + TableRefs *Join +} + +// Accept implements Node Accept interface. +func (n *TableRefsClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*TableRefsClause) + node, ok := n.TableRefs.Accept(v) + if !ok { + return n, false + } + n.TableRefs = node.(*Join) + return v.Leave(n) +} + +// ByItem represents an item in order by or group by. +type ByItem struct { node Expr ExprNode @@ -231,131 +285,244 @@ type OrderByItem struct { } // Accept implements Node Accept interface. -func (ob *OrderByItem) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ob); skipChildren { - return ob, ok +func (n *ByItem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ob.Expr.Accept(v) + n = newNod.(*ByItem) + node, ok := n.Expr.Accept(v) if !ok { - return ob, false + return n, false } - ob.Expr = node.(ExprNode) - return v.Leave(ob) + n.Expr = node.(ExprNode) + return v.Leave(n) +} + +// GroupByClause represents group by clause. +type GroupByClause struct { + node + Items []*ByItem +} + +// Accept implements Node Accept interface. +func (n *GroupByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*GroupByClause) + for i, val := range n.Items { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Items[i] = node.(*ByItem) + } + return v.Leave(n) +} + +// HavingClause represents having clause. +type HavingClause struct { + node + Expr ExprNode +} + +// Accept implements Node Accept interface. +func (n *HavingClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*HavingClause) + node, ok := n.Expr.Accept(v) + if !ok { + return n, false + } + n.Expr = node.(ExprNode) + return v.Leave(n) +} + +// OrderByClause represents order by clause. +type OrderByClause struct { + node + Items []*ByItem + ForUnion bool +} + +// Accept implements Node Accept interface. +func (n *OrderByClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*OrderByClause) + for i, val := range n.Items { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Items[i] = node.(*ByItem) + } + return v.Leave(n) } // SelectStmt represents the select query node. +// See: https://dev.mysql.com/doc/refman/5.7/en/select.html type SelectStmt struct { dmlNode resultSetNode // Distinct represents if the select has distinct option. Distinct bool - // Fields is the select expression list. - Fields []*SelectField // From is the from clause of the query. - From *Join + From *TableRefsClause // Where is the where clause in select statement. Where ExprNode + // Fields is the select expression list. + Fields *FieldList // GroupBy is the group by expression list. - GroupBy []ExprNode + GroupBy *GroupByClause // Having is the having condition. - Having ExprNode - // OrderBy is the odering expression list. - OrderBy []*OrderByItem + Having *HavingClause + // OrderBy is the ordering expression list. + OrderBy *OrderByClause // Limit is the limit clause. Limit *Limit // Lock is the lock type LockTp SelectLockType - - // Union clauses. - Unions []*UnionClause - // Order by for union select. - UnionOrderBy []*OrderByItem - // Limit for union select. - UnionLimit *Limit } // Accept implements Node Accept interface. -func (sn *SelectStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sn); skipChildren { - return sn, ok +func (n *SelectStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range sn.Fields { - node, ok := val.Accept(v) + n = newNod.(*SelectStmt) + + if n.From != nil { + node, ok := n.From.Accept(v) if !ok { - return sn, false + return n, false } - sn.Fields[i] = node.(*SelectField) - } - if sn.From != nil { - node, ok := sn.From.Accept(v) - if !ok { - return sn, false - } - sn.From = node.(*Join) + n.From = node.(*TableRefsClause) } - if sn.Where != nil { - node, ok := sn.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return sn, false + return n, false } - sn.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - for i, val := range sn.GroupBy { - node, ok := val.Accept(v) + if n.Fields != nil { + node, ok := n.Fields.Accept(v) if !ok { - return sn, false + return n, false } - sn.GroupBy[i] = node.(ExprNode) - } - if sn.Having != nil { - node, ok := sn.Having.Accept(v) - if !ok { - return sn, false - } - sn.Having = node.(ExprNode) + n.Fields = node.(*FieldList) } - for i, val := range sn.OrderBy { - node, ok := val.Accept(v) + if n.GroupBy != nil { + node, ok := n.GroupBy.Accept(v) if !ok { - return sn, false + return n, false } - sn.OrderBy[i] = node.(*OrderByItem) + n.GroupBy = node.(*GroupByClause) } - if sn.Limit != nil { - node, ok := sn.Limit.Accept(v) + if n.Having != nil { + node, ok := n.Having.Accept(v) if !ok { - return sn, false + return n, false } - sn.Limit = node.(*Limit) + n.Having = node.(*HavingClause) } - for i, val := range sn.Unions { + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) + if !ok { + return n, false + } + n.OrderBy = node.(*OrderByClause) + } + + if n.Limit != nil { + node, ok := n.Limit.Accept(v) + if !ok { + return n, false + } + n.Limit = node.(*Limit) + } + return v.Leave(n) +} + +// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause. +type UnionClause struct { + node + + Distinct bool + Select *SelectStmt +} + +// Accept implements Node Accept interface. +func (n *UnionClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*UnionClause) + node, ok := n.Select.Accept(v) + if !ok { + return n, false + } + n.Select = node.(*SelectStmt) + return v.Leave(n) +} + +// UnionStmt represents "union statement" +// See: https://dev.mysql.com/doc/refman/5.7/en/union.html +type UnionStmt struct { + dmlNode + resultSetNode + + Distinct bool + Selects []*SelectStmt + OrderBy *OrderByClause + Limit *Limit +} + +// Accept implements Node Accept interface. +func (n *UnionStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*UnionStmt) + for i, val := range n.Selects { node, ok := val.Accept(v) if !ok { - return sn, false + return n, false } - sn.Unions[i] = node.(*UnionClause) + n.Selects[i] = node.(*SelectStmt) } - for i, val := range sn.UnionOrderBy { - node, ok := val.Accept(v) + if n.OrderBy != nil { + node, ok := n.OrderBy.Accept(v) if !ok { - return sn, false + return n, false } - sn.UnionOrderBy[i] = node.(*OrderByItem) + n.OrderBy = node.(*OrderByClause) } - if sn.UnionLimit != nil { - node, ok := sn.UnionLimit.Accept(v) + if n.Limit != nil { + node, ok := n.Limit.Accept(v) if !ok { - return sn, false + return n, false } - sn.UnionLimit = node.(*Limit) + n.Limit = node.(*Limit) } - return v.Leave(sn) + return v.Leave(n) } // Assignment is the expression for assignment, like a = 1. @@ -368,21 +535,23 @@ type Assignment struct { } // Accept implements Node Accept interface. -func (as *Assignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(as); skipChildren { - return as, ok +func (n *Assignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := as.Column.Accept(v) + n = newNod.(*Assignment) + node, ok := n.Column.Accept(v) if !ok { - return as, false + return n, false } - as.Column = node.(*ColumnName) - node, ok = as.Expr.Accept(v) + n.Column = node.(*ColumnName) + node, ok = n.Expr.Accept(v) if !ok { - return as, false + return n, false } - as.Expr = node.(ExprNode) - return v.Leave(as) + n.Expr = node.(ExprNode) + return v.Leave(n) } // Priority const values. @@ -399,58 +568,68 @@ const ( type InsertStmt struct { dmlNode + Replace bool + Table *TableRefsClause Columns []*ColumnName Lists [][]ExprNode - Table *TableName Setlist []*Assignment Priority int OnDuplicate []*Assignment - Select *SelectStmt + Select ResultSetNode } // Accept implements Node Accept interface. -func (in *InsertStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(in); skipChildren { - return in, ok +func (n *InsertStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range in.Columns { + n = newNod.(*InsertStmt) + + if n.Select != nil { + node, ok := n.Select.Accept(v) + if !ok { + return n, false + } + n.Select = node.(ResultSetNode) + } + node, ok := n.Table.Accept(v) + if !ok { + return n, false + } + n.Table = node.(*TableRefsClause) + + for i, val := range n.Columns { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.Columns[i] = node.(*ColumnName) + n.Columns[i] = node.(*ColumnName) } - for i, list := range in.Lists { + for i, list := range n.Lists { for j, val := range list { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.Lists[i][j] = node.(ExprNode) + n.Lists[i][j] = node.(ExprNode) } } - for i, val := range in.Setlist { + for i, val := range n.Setlist { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.Setlist[i] = node.(*Assignment) + n.Setlist[i] = node.(*Assignment) } - for i, val := range in.OnDuplicate { + for i, val := range n.OnDuplicate { node, ok := val.Accept(v) if !ok { - return in, false + return n, false } - in.OnDuplicate[i] = node.(*Assignment) + n.OnDuplicate[i] = node.(*Assignment) } - if in.Select != nil { - node, ok := in.Select.Accept(v) - if !ok { - return in, false - } - in.Select = node.(*SelectStmt) - } - return v.Leave(in) + return v.Leave(n) } // DeleteStmt is a statement to delete rows from table. @@ -458,10 +637,12 @@ func (in *InsertStmt) Accept(v Visitor) (Node, bool) { type DeleteStmt struct { dmlNode - TableRefs *Join + // Used in both single table and multiple table delete statement. + TableRefs *TableRefsClause + // Only used in multiple table delete statement. Tables []*TableName Where ExprNode - Order []*OrderByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -471,47 +652,49 @@ type DeleteStmt struct { } // Accept implements Node Accept interface. -func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(de); skipChildren { - return de, ok +func (n *DeleteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } + n = newNod.(*DeleteStmt) - node, ok := de.TableRefs.Accept(v) + node, ok := n.TableRefs.Accept(v) if !ok { - return de, false + return n, false } - de.TableRefs = node.(*Join) + n.TableRefs = node.(*TableRefsClause) - for i, val := range de.Tables { + for i, val := range n.Tables { node, ok = val.Accept(v) if !ok { - return de, false + return n, false } - de.Tables[i] = node.(*TableName) + n.Tables[i] = node.(*TableName) } - if de.Where != nil { - node, ok = de.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return de, false + return n, false } - de.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - - for i, val := range de.Order { - node, ok = val.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return de, false + return n, false } - de.Order[i] = node.(*OrderByItem) + n.Order = node.(*OrderByClause) } - - node, ok = de.Limit.Accept(v) - if !ok { - return de, false + if n.Limit != nil { + node, ok = n.Limit.Accept(v) + if !ok { + return n, false + } + n.Limit = node.(*Limit) } - de.Limit = node.(*Limit) - return v.Leave(de) + return v.Leave(n) } // UpdateStmt is a statement to update columns of existing rows in tables with new values. @@ -519,10 +702,10 @@ func (de *DeleteStmt) Accept(v Visitor) (Node, bool) { type UpdateStmt struct { dmlNode - TableRefs *Join + TableRefs *TableRefsClause List []*Assignment Where ExprNode - Order []*OrderByItem + Order *OrderByClause Limit *Limit LowPriority bool Ignore bool @@ -530,43 +713,46 @@ type UpdateStmt struct { } // Accept implements Node Accept interface. -func (up *UpdateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(up); skipChildren { - return up, ok +func (n *UpdateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := up.TableRefs.Accept(v) + n = newNod.(*UpdateStmt) + node, ok := n.TableRefs.Accept(v) if !ok { - return up, false + return n, false } - up.TableRefs = node.(*Join) - for i, val := range up.List { + n.TableRefs = node.(*TableRefsClause) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return up, false + return n, false } - up.List[i] = node.(*Assignment) + n.List[i] = node.(*Assignment) } - if up.Where != nil { - node, ok = up.Where.Accept(v) + if n.Where != nil { + node, ok = n.Where.Accept(v) if !ok { - return up, false + return n, false } - up.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - - for i, val := range up.Order { - node, ok = val.Accept(v) + if n.Order != nil { + node, ok = n.Order.Accept(v) if !ok { - return up, false + return n, false } - up.Order[i] = node.(*OrderByItem) + n.Order = node.(*OrderByClause) } - node, ok = up.Limit.Accept(v) - if !ok { - return up, false + if n.Limit != nil { + node, ok = n.Limit.Accept(v) + if !ok { + return n, false + } + n.Limit = node.(*Limit) } - up.Limit = node.(*Limit) - return v.Leave(up) + return v.Leave(n) } // Limit is the limit clause. @@ -578,9 +764,11 @@ type Limit struct { } // Accept implements Node Accept interface. -func (l *Limit) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(l); skipChildren { - return l, ok +func (n *Limit) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(l) + n = newNod.(*Limit) + return v.Leave(n) } diff --git a/ast/expressions.go b/ast/expressions.go index 558f1a019c..c3278bd331 100644 --- a/ast/expressions.go +++ b/ast/expressions.go @@ -14,8 +14,11 @@ package ast import ( + "fmt" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" ) var ( @@ -26,6 +29,7 @@ var ( _ ExprNode = &CaseExpr{} _ ExprNode = &SubqueryExpr{} _ ExprNode = &CompareSubqueryExpr{} + _ Node = &ColumnName{} _ ExprNode = &ColumnNameExpr{} _ ExprNode = &DefaultExpr{} _ ExprNode = &IdentifierExpr{} @@ -47,21 +51,58 @@ var ( // ValueExpr is the simple value expression. type ValueExpr struct { exprNode - // Val is the literal value. - Val interface{} +} + +// NewValueExpr creates a ValueExpr with value, and sets default field type. +func NewValueExpr(value interface{}) *ValueExpr { + ve := &ValueExpr{} + ve.Data = types.RawData(value) + // TODO: make it more precise. + switch value.(type) { + case nil: + ve.Type = types.NewFieldType(mysql.TypeNull) + case bool, int64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + case uint64: + ve.Type = types.NewFieldType(mysql.TypeLonglong) + ve.Type.Flag |= mysql.UnsignedFlag + case string, UnquoteString: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = mysql.DefaultCharset + ve.Type.Collate = mysql.DefaultCollationName + case float64: + ve.Type = types.NewFieldType(mysql.TypeDouble) + case []byte: + ve.Type = types.NewFieldType(mysql.TypeBlob) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + case mysql.Bit: + ve.Type = types.NewFieldType(mysql.TypeBit) + case mysql.Hex: + ve.Type = types.NewFieldType(mysql.TypeVarchar) + ve.Type.Charset = "binary" + ve.Type.Collate = "binary" + case *types.DataItem: + ve.Type = value.(*types.DataItem).Type + default: + panic(fmt.Sprintf("illegal literal value type:%T", value)) + } + return ve } // IsStatic implements ExprNode interface. -func (val *ValueExpr) IsStatic() bool { +func (n *ValueExpr) IsStatic() bool { return true } // Accept implements Node interface. -func (val *ValueExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(val); skipChildren { - return val, ok +func (n *ValueExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(val) + n = newNod.(*ValueExpr) + return v.Leave(n) } // BetweenExpr is for "between and" or "not between and" expression. @@ -78,35 +119,37 @@ type BetweenExpr struct { } // Accept implements Node interface. -func (b *BetweenExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(b); skipChildren { - return b, ok +func (n *BetweenExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } + n = newNod.(*BetweenExpr) - node, ok := b.Expr.Accept(v) + node, ok := n.Expr.Accept(v) if !ok { - return b, false + return n, false } - b.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = b.Left.Accept(v) + node, ok = n.Left.Accept(v) if !ok { - return b, false + return n, false } - b.Left = node.(ExprNode) + n.Left = node.(ExprNode) - node, ok = b.Right.Accept(v) + node, ok = n.Right.Accept(v) if !ok { - return b, false + return n, false } - b.Right = node.(ExprNode) + n.Right = node.(ExprNode) - return v.Leave(b) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (b *BetweenExpr) IsStatic() bool { - return b.Expr.IsStatic() && b.Left.IsStatic() && b.Right.IsStatic() +func (n *BetweenExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Left.IsStatic() && n.Right.IsStatic() } // BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc. @@ -121,29 +164,31 @@ type BinaryOperationExpr struct { } // Accept implements Node interface. -func (o *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(o); skipChildren { - return o, ok +func (n *BinaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } + n = newNod.(*BinaryOperationExpr) - node, ok := o.L.Accept(v) + node, ok := n.L.Accept(v) if !ok { - return o, false + return n, false } - o.L = node.(ExprNode) + n.L = node.(ExprNode) - node, ok = o.R.Accept(v) + node, ok = n.R.Accept(v) if !ok { - return o, false + return n, false } - o.R = node.(ExprNode) + n.R = node.(ExprNode) - return v.Leave(o) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (o *BinaryOperationExpr) IsStatic() bool { - return o.L.IsStatic() && o.R.IsStatic() +func (n *BinaryOperationExpr) IsStatic() bool { + return n.L.IsStatic() && n.R.IsStatic() } // WhenClause is the when clause in Case expression for "when condition then result". @@ -156,27 +201,29 @@ type WhenClause struct { } // Accept implements Node Accept interface. -func (w *WhenClause) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(w); skipChildren { - return w, ok +func (n *WhenClause) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := w.Expr.Accept(v) + n = newNod.(*WhenClause) + node, ok := n.Expr.Accept(v) if !ok { - return w, false + return n, false } - w.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) - node, ok = w.Result.Accept(v) + node, ok = n.Result.Accept(v) if !ok { - return w, false + return n, false } - w.Result = node.(ExprNode) - return v.Leave(w) + n.Result = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (w *WhenClause) IsStatic() bool { - return w.Expr.IsStatic() && w.Result.IsStatic() +func (n *WhenClause) IsStatic() bool { + return n.Expr.IsStatic() && n.Result.IsStatic() } // CaseExpr is the case expression. @@ -191,45 +238,47 @@ type CaseExpr struct { } // Accept implements Node Accept interface. -func (f *CaseExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (n *CaseExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if f.Value != nil { - node, ok := f.Value.Accept(v) + n = newNod.(*CaseExpr) + if n.Value != nil { + node, ok := n.Value.Accept(v) if !ok { - return f, false + return n, false } - f.Value = node.(ExprNode) + n.Value = node.(ExprNode) } - for i, val := range f.WhenClauses { + for i, val := range n.WhenClauses { node, ok := val.Accept(v) if !ok { - return f, false + return n, false } - f.WhenClauses[i] = node.(*WhenClause) + n.WhenClauses[i] = node.(*WhenClause) } - if f.ElseClause != nil { - node, ok := f.ElseClause.Accept(v) + if n.ElseClause != nil { + node, ok := n.ElseClause.Accept(v) if !ok { - return f, false + return n, false } - f.ElseClause = node.(ExprNode) + n.ElseClause = node.(ExprNode) } - return v.Leave(f) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (f *CaseExpr) IsStatic() bool { - if f.Value != nil && !f.Value.IsStatic() { +func (n *CaseExpr) IsStatic() bool { + if n.Value != nil && !n.Value.IsStatic() { return false } - for _, w := range f.WhenClauses { + for _, w := range n.WhenClauses { if !w.IsStatic() { return false } } - if f.ElseClause != nil && !f.ElseClause.IsStatic() { + if n.ElseClause != nil && !n.ElseClause.IsStatic() { return false } return true @@ -239,30 +288,32 @@ func (f *CaseExpr) IsStatic() bool { type SubqueryExpr struct { exprNode // Query is the query SelectNode. - Query *SelectStmt + Query ResultSetNode } // Accept implements Node Accept interface. -func (sq *SubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sq); skipChildren { - return sq, ok +func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := sq.Query.Accept(v) + n = newNod.(*SubqueryExpr) + node, ok := n.Query.Accept(v) if !ok { - return sq, false + return n, false } - sq.Query = node.(*SelectStmt) - return v.Leave(sq) + n.Query = node.(ResultSetNode) + return v.Leave(n) } // SetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) SetResultFields(rfs []*ResultField) { - sq.Query.SetResultFields(rfs) +func (n *SubqueryExpr) SetResultFields(rfs []*ResultField) { + n.Query.SetResultFields(rfs) } // GetResultFields implements ResultSet interface. -func (sq *SubqueryExpr) GetResultFields() []*ResultField { - return sq.Query.GetResultFields() +func (n *SubqueryExpr) GetResultFields() []*ResultField { + return n.Query.GetResultFields() } // CompareSubqueryExpr is the expression for "expr cmp (select ...)". @@ -282,21 +333,23 @@ type CompareSubqueryExpr struct { } // Accept implements Node Accept interface. -func (cs *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (n *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := cs.L.Accept(v) + n = newNod.(*CompareSubqueryExpr) + node, ok := n.L.Accept(v) if !ok { - return cs, false + return n, false } - cs.L = node.(ExprNode) - node, ok = cs.R.Accept(v) + n.L = node.(ExprNode) + node, ok = n.R.Accept(v) if !ok { - return cs, false + return n, false } - cs.R = node.(*SubqueryExpr) - return v.Leave(cs) + n.R = node.(*SubqueryExpr) + return v.Leave(n) } // ColumnName represents column name. @@ -312,11 +365,13 @@ type ColumnName struct { } // Accept implements Node Accept interface. -func (cn *ColumnName) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cn); skipChildren { - return cn, ok +func (n *ColumnName) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cn) + n = newNod.(*ColumnName) + return v.Leave(n) } // ColumnNameExpr represents a column name expression. @@ -328,16 +383,18 @@ type ColumnNameExpr struct { } // Accept implements Node Accept interface. -func (cr *ColumnNameExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cr); skipChildren { - return cr, ok +func (n *ColumnNameExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := cr.Name.Accept(v) + n = newNod.(*ColumnNameExpr) + node, ok := n.Name.Accept(v) if !ok { - return cr, false + return n, false } - cr.Name = node.(*ColumnName) - return v.Leave(cr) + n.Name = node.(*ColumnName) + return v.Leave(n) } // DefaultExpr is the default expression using default value for a column. @@ -348,18 +405,20 @@ type DefaultExpr struct { } // Accept implements Node Accept interface. -func (d *DefaultExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(d); skipChildren { - return d, ok +func (n *DefaultExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if d.Name != nil { - node, ok := d.Name.Accept(v) + n = newNod.(*DefaultExpr) + if n.Name != nil { + node, ok := n.Name.Accept(v) if !ok { - return d, false + return n, false } - d.Name = node.(*ColumnName) + n.Name = node.(*ColumnName) } - return v.Leave(d) + return v.Leave(n) } // IdentifierExpr represents an identifier expression. @@ -370,11 +429,13 @@ type IdentifierExpr struct { } // Accept implements Node Accept interface. -func (i *IdentifierExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(i); skipChildren { - return i, ok +func (n *IdentifierExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(i) + n = newNod.(*IdentifierExpr) + return v.Leave(n) } // ExistsSubqueryExpr is the expression for "exists (select ...)". @@ -386,16 +447,18 @@ type ExistsSubqueryExpr struct { } // Accept implements Node Accept interface. -func (es *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (n *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Sel.Accept(v) + n = newNod.(*ExistsSubqueryExpr) + node, ok := n.Sel.Accept(v) if !ok { - return es, false + return n, false } - es.Sel = node.(*SubqueryExpr) - return v.Leave(es) + n.Sel = node.(*SubqueryExpr) + return v.Leave(n) } // PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)". @@ -412,30 +475,32 @@ type PatternInExpr struct { } // Accept implements Node Accept interface. -func (pi *PatternInExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pi); skipChildren { - return pi, ok +func (n *PatternInExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := pi.Expr.Accept(v) + n = newNod.(*PatternInExpr) + node, ok := n.Expr.Accept(v) if !ok { - return pi, false + return n, false } - pi.Expr = node.(ExprNode) - for i, val := range pi.List { + n.Expr = node.(ExprNode) + for i, val := range n.List { node, ok = val.Accept(v) if !ok { - return pi, false + return n, false } - pi.List[i] = node.(ExprNode) + n.List[i] = node.(ExprNode) } - if pi.Sel != nil { - node, ok = pi.Sel.Accept(v) + if n.Sel != nil { + node, ok = n.Sel.Accept(v) if !ok { - return pi, false + return n, false } - pi.Sel = node.(*SubqueryExpr) + n.Sel = node.(*SubqueryExpr) } - return v.Leave(pi) + return v.Leave(n) } // IsNullExpr is the expression for null check. @@ -448,21 +513,23 @@ type IsNullExpr struct { } // Accept implements Node Accept interface. -func (is *IsNullExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (n *IsNullExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + n = newNod.(*IsNullExpr) + node, ok := n.Expr.Accept(v) if !ok { - return is, false + return n, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsNullExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (n *IsNullExpr) IsStatic() bool { + return n.Expr.IsStatic() } // IsTruthExpr is the expression for true/false check. @@ -477,21 +544,23 @@ type IsTruthExpr struct { } // Accept implements Node Accept interface. -func (is *IsTruthExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(is); skipChildren { - return is, ok +func (n *IsTruthExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := is.Expr.Accept(v) + n = newNod.(*IsTruthExpr) + node, ok := n.Expr.Accept(v) if !ok { - return is, false + return n, false } - is.Expr = node.(ExprNode) - return v.Leave(is) + n.Expr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (is *IsTruthExpr) IsStatic() bool { - return is.Expr.IsStatic() +func (n *IsTruthExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PatternLikeExpr is the expression for like operator, e.g, expr like "%123%" @@ -508,40 +577,49 @@ type PatternLikeExpr struct { } // Accept implements Node Accept interface. -func (pl *PatternLikeExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pl); skipChildren { - return pl, ok +func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := pl.Expr.Accept(v) - if !ok { - return pl, false + n = newNod.(*PatternLikeExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) + if !ok { + return n, false + } + n.Expr = node.(ExprNode) } - pl.Expr = node.(ExprNode) - node, ok = pl.Pattern.Accept(v) - if !ok { - return pl, false + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) + if !ok { + return n, false + } + n.Pattern = node.(ExprNode) } - pl.Pattern = node.(ExprNode) - return v.Leave(pl) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (pl *PatternLikeExpr) IsStatic() bool { - return pl.Expr.IsStatic() && pl.Pattern.IsStatic() +func (n *PatternLikeExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // ParamMarkerExpr expresion holds a place for another expression. // Used in parsing prepare statement. type ParamMarkerExpr struct { exprNode + Offset int } // Accept implements Node Accept interface. -func (pm *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(pm); skipChildren { - return pm, ok +func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(pm) + n = newNod.(*ParamMarkerExpr) + return v.Leave(n) } // ParenthesesExpr is the parentheses expression. @@ -552,23 +630,25 @@ type ParenthesesExpr struct { } // Accept implements Node Accept interface. -func (p *ParenthesesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (n *ParenthesesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if p.Expr != nil { - node, ok := p.Expr.Accept(v) + n = newNod.(*ParenthesesExpr) + if n.Expr != nil { + node, ok := n.Expr.Accept(v) if !ok { - return p, false + return n, false } - p.Expr = node.(ExprNode) + n.Expr = node.(ExprNode) } - return v.Leave(p) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (p *ParenthesesExpr) IsStatic() bool { - return p.Expr.IsStatic() +func (n *ParenthesesExpr) IsStatic() bool { + return n.Expr.IsStatic() } // PositionExpr is the expression for order by and group by position. @@ -583,16 +663,18 @@ type PositionExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (p *PositionExpr) IsStatic() bool { +func (n *PositionExpr) IsStatic() bool { return true } // Accept implements Node Accept interface. -func (p *PositionExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (n *PositionExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(p) + n = newNod.(*PositionExpr) + return v.Leave(n) } // PatternRegexpExpr is the pattern expression for pattern match. @@ -607,26 +689,28 @@ type PatternRegexpExpr struct { } // Accept implements Node Accept interface. -func (p *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(p); skipChildren { - return p, ok +func (n *PatternRegexpExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := p.Expr.Accept(v) + n = newNod.(*PatternRegexpExpr) + node, ok := n.Expr.Accept(v) if !ok { - return p, false + return n, false } - p.Expr = node.(ExprNode) - node, ok = p.Pattern.Accept(v) + n.Expr = node.(ExprNode) + node, ok = n.Pattern.Accept(v) if !ok { - return p, false + return n, false } - p.Pattern = node.(ExprNode) - return v.Leave(p) + n.Pattern = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (p *PatternRegexpExpr) IsStatic() bool { - return p.Expr.IsStatic() && p.Pattern.IsStatic() +func (n *PatternRegexpExpr) IsStatic() bool { + return n.Expr.IsStatic() && n.Pattern.IsStatic() } // RowExpr is the expression for row constructor. @@ -638,23 +722,25 @@ type RowExpr struct { } // Accept implements Node Accept interface. -func (r *RowExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(r); skipChildren { - return r, ok +func (n *RowExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range r.Values { + n = newNod.(*RowExpr) + for i, val := range n.Values { node, ok := val.Accept(v) if !ok { - return r, false + return n, false } - r.Values[i] = node.(ExprNode) + n.Values[i] = node.(ExprNode) } - return v.Leave(r) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (r *RowExpr) IsStatic() bool { - for _, v := range r.Values { +func (n *RowExpr) IsStatic() bool { + for _, v := range n.Values { if !v.IsStatic() { return false } @@ -672,21 +758,23 @@ type UnaryOperationExpr struct { } // Accept implements Node Accept interface. -func (u *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(u); skipChildren { - return u, ok +func (n *UnaryOperationExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := u.V.Accept(v) + n = newNod.(*UnaryOperationExpr) + node, ok := n.V.Accept(v) if !ok { - return u, false + return n, false } - u.V = node.(ExprNode) - return v.Leave(u) + n.V = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (u *UnaryOperationExpr) IsStatic() bool { - return u.V.IsStatic() +func (n *UnaryOperationExpr) IsStatic() bool { + return n.V.IsStatic() } // ValuesExpr is the expression used in INSERT VALUES @@ -697,16 +785,18 @@ type ValuesExpr struct { } // Accept implements Node Accept interface. -func (va *ValuesExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (n *ValuesExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Column.Accept(v) + n = newNod.(*ValuesExpr) + node, ok := n.Column.Accept(v) if !ok { - return va, false + return n, false } - va.Column = node.(*ColumnName) - return v.Leave(va) + n.Column = node.(*ColumnName) + return v.Leave(n) } // VariableExpr is the expression for variable. @@ -721,9 +811,11 @@ type VariableExpr struct { } // Accept implements Node Accept interface. -func (va *VariableExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (n *VariableExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(va) + n = newNod.(*VariableExpr) + return v.Leave(n) } diff --git a/ast/functions.go b/ast/functions.go index 3843478313..7faab4c6e1 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -26,44 +26,49 @@ var ( _ FuncNode = &FuncConvertExpr{} _ FuncNode = &FuncCastExpr{} _ FuncNode = &FuncSubstringExpr{} + _ FuncNode = &FuncLocateExpr{} _ FuncNode = &FuncTrimExpr{} + _ FuncNode = &FuncDateArithExpr{} + _ FuncNode = &AggregateFuncExpr{} ) +// UnquoteString is not quoted when printed. +type UnquoteString string + // FuncCallExpr is for function expression. type FuncCallExpr struct { funcNode // F is the function name. - F string + FnName string // Args is the function args. Args []ExprNode - // Distinct only affetcts sum, avg, count, group_concat, - // so we can ignore it in other functions - Distinct bool } // Accept implements Node interface. -func (c *FuncCallExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(c); skipChildren { - return c, ok +func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range c.Args { + n = newNod.(*FuncCallExpr) + for i, val := range n.Args { node, ok := val.Accept(v) if !ok { - return c, false + return n, false } - c.Args[i] = node.(ExprNode) + n.Args[i] = node.(ExprNode) } - return v.Leave(c) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (c *FuncCallExpr) IsStatic() bool { - v := builtin.Funcs[strings.ToLower(c.F)] +func (n *FuncCallExpr) IsStatic() bool { + v := builtin.Funcs[strings.ToLower(n.FnName)] if v.F == nil || !v.IsStatic { return false } - for _, v := range c.Args { + for _, v := range n.Args { if !v.IsStatic() { return false } @@ -81,21 +86,23 @@ type FuncExtractExpr struct { } // Accept implements Node Accept interface. -func (ex *FuncExtractExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ex); skipChildren { - return ex, ok +func (n *FuncExtractExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ex.Date.Accept(v) + n = newNod.(*FuncExtractExpr) + node, ok := n.Date.Accept(v) if !ok { - return ex, false + return n, false } - ex.Date = node.(ExprNode) - return v.Leave(ex) + n.Date = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (ex *FuncExtractExpr) IsStatic() bool { - return ex.Date.IsStatic() +func (n *FuncExtractExpr) IsStatic() bool { + return n.Date.IsStatic() } // FuncConvertExpr provides a way to convert data between different character sets. @@ -109,21 +116,23 @@ type FuncConvertExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncConvertExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (n *FuncConvertExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncConvertExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (n *FuncConvertExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + n = newNod.(*FuncConvertExpr) + node, ok := n.Expr.Accept(v) if !ok { - return f, false + return n, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + n.Expr = node.(ExprNode) + return v.Leave(n) } // CastFunctionType is the type for cast function. @@ -149,21 +158,23 @@ type FuncCastExpr struct { } // IsStatic implements the ExprNode IsStatic interface. -func (f *FuncCastExpr) IsStatic() bool { - return f.Expr.IsStatic() +func (n *FuncCastExpr) IsStatic() bool { + return n.Expr.IsStatic() } // Accept implements Node Accept interface. -func (f *FuncCastExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(f); skipChildren { - return f, ok +func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := f.Expr.Accept(v) + n = newNod.(*FuncCastExpr) + node, ok := n.Expr.Accept(v) if !ok { - return f, false + return n, false } - f.Expr = node.(ExprNode) - return v.Leave(f) + n.Expr = node.(ExprNode) + return v.Leave(n) } // FuncSubstringExpr returns the substring as specified. @@ -177,31 +188,35 @@ type FuncSubstringExpr struct { } // Accept implements Node Accept interface. -func (sf *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(sf); skipChildren { - return sf, ok +func (n *FuncSubstringExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := sf.StrExpr.Accept(v) + n = newNod.(*FuncSubstringExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return sf, false + return n, false } - sf.StrExpr = node.(ExprNode) - node, ok = sf.Pos.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return sf, false + return n, false } - sf.Pos = node.(ExprNode) - node, ok = sf.Len.Accept(v) - if !ok { - return sf, false + n.Pos = node.(ExprNode) + if n.Len != nil { + node, ok = n.Len.Accept(v) + if !ok { + return n, false + } + n.Len = node.(ExprNode) } - sf.Len = node.(ExprNode) - return v.Leave(sf) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (sf *FuncSubstringExpr) IsStatic() bool { - return sf.StrExpr.IsStatic() && sf.Pos.IsStatic() && sf.Len.IsStatic() +func (n *FuncSubstringExpr) IsStatic() bool { + return n.StrExpr.IsStatic() && n.Pos.IsStatic() && n.Len.IsStatic() } // FuncSubstringIndexExpr returns the substring as specified. @@ -215,26 +230,28 @@ type FuncSubstringIndexExpr struct { } // Accept implements Node Accept interface. -func (si *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(si); skipChildren { - return si, ok +func (n *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := si.StrExpr.Accept(v) + n = newNod.(*FuncSubstringIndexExpr) + node, ok := n.StrExpr.Accept(v) if !ok { - return si, false + return n, false } - si.StrExpr = node.(ExprNode) - node, ok = si.Delim.Accept(v) + n.StrExpr = node.(ExprNode) + node, ok = n.Delim.Accept(v) if !ok { - return si, false + return n, false } - si.Delim = node.(ExprNode) - node, ok = si.Count.Accept(v) + n.Delim = node.(ExprNode) + node, ok = n.Count.Accept(v) if !ok { - return si, false + return n, false } - si.Count = node.(ExprNode) - return v.Leave(si) + n.Count = node.(ExprNode) + return v.Leave(n) } // FuncLocateExpr returns the position of the first occurrence of substring. @@ -248,26 +265,28 @@ type FuncLocateExpr struct { } // Accept implements Node Accept interface. -func (le *FuncLocateExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(le); skipChildren { - return le, ok +func (n *FuncLocateExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := le.Str.Accept(v) + n = newNod.(*FuncLocateExpr) + node, ok := n.Str.Accept(v) if !ok { - return le, false + return n, false } - le.Str = node.(ExprNode) - node, ok = le.SubStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.SubStr.Accept(v) if !ok { - return le, false + return n, false } - le.SubStr = node.(ExprNode) - node, ok = le.Pos.Accept(v) + n.SubStr = node.(ExprNode) + node, ok = n.Pos.Accept(v) if !ok { - return le, false + return n, false } - le.Pos = node.(ExprNode) - return v.Leave(le) + n.Pos = node.(ExprNode) + return v.Leave(n) } // TrimDirectionType is the type for trim direction. @@ -295,27 +314,102 @@ type FuncTrimExpr struct { } // Accept implements Node Accept interface. -func (tf *FuncTrimExpr) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(tf); skipChildren { - return tf, ok +func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := tf.Str.Accept(v) + n = newNod.(*FuncTrimExpr) + node, ok := n.Str.Accept(v) if !ok { - return tf, false + return n, false } - tf.Str = node.(ExprNode) - node, ok = tf.RemStr.Accept(v) + n.Str = node.(ExprNode) + node, ok = n.RemStr.Accept(v) if !ok { - return tf, false + return n, false } - tf.RemStr = node.(ExprNode) - return v.Leave(tf) + n.RemStr = node.(ExprNode) + return v.Leave(n) } // IsStatic implements the ExprNode IsStatic interface. -func (tf *FuncTrimExpr) IsStatic() bool { - return tf.Str.IsStatic() && tf.RemStr.IsStatic() +func (n *FuncTrimExpr) IsStatic() bool { + return n.Str.IsStatic() && n.RemStr.IsStatic() } -// TypeStar is a special type for "*". -type TypeStar string +// DateArithType is type for DateArith option. +type DateArithType byte + +const ( + // DateAdd is to run date_add function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add + DateAdd DateArithType = iota + 1 + // DateSub is to run date_sub function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub + DateSub +) + +// FuncDateArithExpr is the struct for date arithmetic functions. +type FuncDateArithExpr struct { + funcNode + + Op DateArithType + Unit string + Date ExprNode + Interval ExprNode +} + +// Accept implements Node Accept interface. +func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*FuncDateArithExpr) + if n.Date != nil { + node, ok := n.Date.Accept(v) + if !ok { + return n, false + } + n.Date = node.(ExprNode) + } + if n.Interval != nil { + node, ok := n.Date.Accept(v) + if !ok { + return n, false + } + n.Date = node.(ExprNode) + } + return v.Leave(n) +} + +// AggregateFuncExpr represents aggregate function expression. +type AggregateFuncExpr struct { + funcNode + // F is the function name. + F string + // Args is the function args. + Args []ExprNode + // If distinct is true, the function only aggregate distinct values. + // For example, column c1 values are "1", "2", "2", "sum(c1)" is "5", + // but "sum(distinct c1)" is "3". + Distinct bool +} + +// Accept implements Node Accept interface. +func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*AggregateFuncExpr) + for i, val := range n.Args { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Args[i] = node.(ExprNode) + } + return v.Leave(n) +} diff --git a/ast/misc.go b/ast/misc.go index 2ad94e7755..4f287b4680 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -13,6 +13,8 @@ package ast +import "github.com/pingcap/tidb/mysql" + var ( _ StmtNode = &ExplainStmt{} _ StmtNode = &PrepareStmt{} @@ -26,7 +28,9 @@ var ( _ StmtNode = &SetStmt{} _ StmtNode = &SetCharsetStmt{} _ StmtNode = &SetPwdStmt{} + _ StmtNode = &CreateUserStmt{} _ StmtNode = &DoStmt{} + _ StmtNode = &GrantStmt{} _ Node = &VariableAssignment{} ) @@ -57,16 +61,18 @@ type ExplainStmt struct { } // Accept implements Node Accept interface. -func (es *ExplainStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (n *ExplainStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := es.Stmt.Accept(v) + n = newNod.(*ExplainStmt) + node, ok := n.Stmt.Accept(v) if !ok { - return es, false + return n, false } - es.Stmt = node.(DMLNode) - return v.Leave(es) + n.Stmt = node.(DMLNode) + return v.Leave(n) } // PrepareStmt is a statement to prepares a SQL statement which contains placeholders, @@ -83,16 +89,18 @@ type PrepareStmt struct { } // Accept implements Node Accept interface. -func (ps *PrepareStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ps); skipChildren { - return ps, ok +func (n *PrepareStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := ps.SQLVar.Accept(v) + n = newNod.(*PrepareStmt) + node, ok := n.SQLVar.Accept(v) if !ok { - return ps, false + return n, false } - ps.SQLVar = node.(*VariableExpr) - return v.Leave(ps) + n.SQLVar = node.(*VariableExpr) + return v.Leave(n) } // DeallocateStmt is a statement to release PreparedStmt. @@ -105,11 +113,13 @@ type DeallocateStmt struct { } // Accept implements Node Accept interface. -func (ds *DeallocateStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ds); skipChildren { - return ds, ok +func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(ds) + n = newNod.(*DeallocateStmt) + return v.Leave(n) } // ExecuteStmt is a statement to execute PreparedStmt. @@ -123,18 +133,20 @@ type ExecuteStmt struct { } // Accept implements Node Accept interface. -func (es *ExecuteStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(es); skipChildren { - return es, ok +func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range es.UsingVars { + n = newNod.(*ExecuteStmt) + for i, val := range n.UsingVars { node, ok := val.Accept(v) if !ok { - return es, false + return n, false } - es.UsingVars[i] = node.(ExprNode) + n.UsingVars[i] = node.(ExprNode) } - return v.Leave(es) + return v.Leave(n) } // ShowStmtType is the type for SHOW statement. @@ -152,12 +164,13 @@ const ( ShowVariables ShowCollation ShowCreateTable + ShowGrants ) // ShowStmt is a statement to provide information about databases, tables, columns and so on. // See: https://dev.mysql.com/doc/refman/5.7/en/show.html type ShowStmt struct { - stmtNode + dmlNode Tp ShowStmtType // Databases/Tables/Columns/.... DBName string @@ -165,6 +178,7 @@ type ShowStmt struct { Column *ColumnName // Used for `desc table column`. Flag int // Some flag parsed from sql, such as FULL. Full bool + User string // Used for show grants. // Used by show variables GlobalScope bool @@ -173,39 +187,41 @@ type ShowStmt struct { } // Accept implements Node Accept interface. -func (ss *ShowStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(ss); skipChildren { - return ss, ok +func (n *ShowStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - if ss.Table != nil { - node, ok := ss.Table.Accept(v) + n = newNod.(*ShowStmt) + if n.Table != nil { + node, ok := n.Table.Accept(v) if !ok { - return ss, false + return n, false } - ss.Table = node.(*TableName) + n.Table = node.(*TableName) } - if ss.Column != nil { - node, ok := ss.Column.Accept(v) + if n.Column != nil { + node, ok := n.Column.Accept(v) if !ok { - return ss, false + return n, false } - ss.Column = node.(*ColumnName) + n.Column = node.(*ColumnName) } - if ss.Pattern != nil { - node, ok := ss.Pattern.Accept(v) + if n.Pattern != nil { + node, ok := n.Pattern.Accept(v) if !ok { - return ss, false + return n, false } - ss.Pattern = node.(*PatternLikeExpr) + n.Pattern = node.(*PatternLikeExpr) } - if ss.Where != nil { - node, ok := ss.Where.Accept(v) + if n.Where != nil { + node, ok := n.Where.Accept(v) if !ok { - return ss, false + return n, false } - ss.Where = node.(ExprNode) + n.Where = node.(ExprNode) } - return v.Leave(ss) + return v.Leave(n) } // BeginStmt is a statement to start a new transaction. @@ -215,11 +231,13 @@ type BeginStmt struct { } // Accept implements Node Accept interface. -func (bs *BeginStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(bs); skipChildren { - return bs, ok +func (n *BeginStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(bs) + n = newNod.(*BeginStmt) + return v.Leave(n) } // CommitStmt is a statement to commit the current transaction. @@ -229,11 +247,13 @@ type CommitStmt struct { } // Accept implements Node Accept interface. -func (cs *CommitStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(cs); skipChildren { - return cs, ok +func (n *CommitStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(cs) + n = newNod.(*CommitStmt) + return v.Leave(n) } // RollbackStmt is a statement to roll back the current transaction. @@ -243,11 +263,13 @@ type RollbackStmt struct { } // Accept implements Node Accept interface. -func (rs *RollbackStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(rs); skipChildren { - return rs, ok +func (n *RollbackStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(rs) + n = newNod.(*RollbackStmt) + return v.Leave(n) } // UseStmt is a statement to use the DBName database as the current database. @@ -259,11 +281,13 @@ type UseStmt struct { } // Accept implements Node Accept interface. -func (us *UseStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(us); skipChildren { - return us, ok +func (n *UseStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(us) + n = newNod.(*UseStmt) + return v.Leave(n) } // VariableAssignment is a variable assignment struct. @@ -276,16 +300,18 @@ type VariableAssignment struct { } // Accept implements Node interface. -func (va *VariableAssignment) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(va); skipChildren { - return va, ok +func (n *VariableAssignment) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - node, ok := va.Value.Accept(v) + n = newNod.(*VariableAssignment) + node, ok := n.Value.Accept(v) if !ok { - return va, false + return n, false } - va.Value = node.(ExprNode) - return v.Leave(va) + n.Value = node.(ExprNode) + return v.Leave(n) } // SetStmt is the statement to set variables. @@ -296,18 +322,20 @@ type SetStmt struct { } // Accept implements Node Accept interface. -func (set *SetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (n *SetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range set.Variables { + n = newNod.(*SetStmt) + for i, val := range n.Variables { node, ok := val.Accept(v) if !ok { - return set, false + return n, false } - set.Variables[i] = node.(*VariableAssignment) + n.Variables[i] = node.(*VariableAssignment) } - return v.Leave(set) + return v.Leave(n) } // SetCharsetStmt is a statement to assign values to character and collation variables. @@ -320,11 +348,13 @@ type SetCharsetStmt struct { } // Accept implements Node Accept interface. -func (set *SetCharsetStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + n = newNod.(*SetCharsetStmt) + return v.Leave(n) } // SetPwdStmt is a statement to assign a password to user account. @@ -337,11 +367,13 @@ type SetPwdStmt struct { } // Accept implements Node Accept interface. -func (set *SetPwdStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(set); skipChildren { - return set, ok +func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - return v.Leave(set) + n = newNod.(*SetPwdStmt) + return v.Leave(n) } // UserSpec is used for parsing create user statement. @@ -353,10 +385,20 @@ type UserSpec struct { // CreateUserStmt creates user account. // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html type CreateUserStmt struct { + stmtNode + IfNotExists bool Specs []*UserSpec +} - Text string +// Accept implements Node Accept interface. +func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*CreateUserStmt) + return v.Leave(n) } // DoStmt is the struct for DO statement. @@ -367,16 +409,100 @@ type DoStmt struct { } // Accept implements Node Accept interface. -func (do *DoStmt) Accept(v Visitor) (Node, bool) { - if skipChildren, ok := v.Enter(do); skipChildren { - return do, ok +func (n *DoStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) } - for i, val := range do.Exprs { + n = newNod.(*DoStmt) + for i, val := range n.Exprs { node, ok := val.Accept(v) if !ok { - return do, false + return n, false } - do.Exprs[i] = node.(ExprNode) + n.Exprs[i] = node.(ExprNode) } - return v.Leave(do) + return v.Leave(n) +} + +// PrivElem is the privilege type and optional column list. +type PrivElem struct { + node + Priv mysql.PrivilegeType + Cols []*ColumnName +} + +// Accept implements Node Accept interface. +func (n *PrivElem) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*PrivElem) + for i, val := range n.Cols { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Cols[i] = node.(*ColumnName) + } + return v.Leave(n) +} + +// ObjectTypeType is the type for object type. +type ObjectTypeType int + +const ( + // ObjectTypeNone is for empty object type. + ObjectTypeNone ObjectTypeType = iota + // ObjectTypeTable means the following object is a table. + ObjectTypeTable +) + +// GrantLevelType is the type for grant level. +type GrantLevelType int + +const ( + // GrantLevelNone is the dummy const for default value. + GrantLevelNone GrantLevelType = iota + // GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server. + GrantLevelGlobal + // GrantLevelDB means the privileges apply to all objects in a given database. + GrantLevelDB + // GrantLevelTable means the privileges apply to all columns in a given table. + GrantLevelTable +) + +// GrantLevel is used for store the privilege scope. +type GrantLevel struct { + Level GrantLevelType + DBName string + TableName string +} + +// GrantStmt is the struct for GRANT statement. +type GrantStmt struct { + stmtNode + + Privs []*PrivElem + ObjectType ObjectTypeType + Level *GrantLevel + Users []*UserSpec +} + +// Accept implements Node Accept interface. +func (n *GrantStmt) Accept(v Visitor) (Node, bool) { + newNod, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNod) + } + n = newNod.(*GrantStmt) + for i, val := range n.Privs { + node, ok := val.Accept(v) + if !ok { + return n, false + } + n.Privs[i] = node.(*PrivElem) + } + return v.Leave(n) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index 24cde087e6..1393a7dff1 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -104,6 +104,8 @@ import ( currentUser "CURRENT_USER" database "DATABASE" databases "DATABASES" + dateAdd "DATE_ADD" + dateSub "DATE_SUB" day "DAY" dayofmonth "DAYOFMONTH" dayofweek "DAYOFWEEK" @@ -141,6 +143,8 @@ import ( fulltext "FULLTEXT" ge ">=" global "GLOBAL" + grant "GRANT" + grants "GRANTS" group "GROUP" groupConcat "GROUP_CONCAT" having "HAVING" @@ -154,6 +158,7 @@ import ( index "INDEX" inner "INNER" insert "INSERT" + interval "INTERVAL" into "INTO" is "IS" join "JOIN" @@ -190,6 +195,7 @@ import ( nullIf "NULLIF" offset "OFFSET" on "ON" + option "OPTION" or "OR" order "ORDER" oror "||" @@ -204,6 +210,7 @@ import ( references "REFERENCES" regexp "REGEXP" repeat "REPEAT" + replace "REPLACE" right "RIGHT" rlike "RLIKE" rollback "ROLLBACK" @@ -229,11 +236,13 @@ import ( tableKwd "TABLE" tables "TABLES" then "THEN" + to "TO" trailing "TRAILING" transaction "TRANSACTION" trim "TRIM" trueKwd "true" truncate "TRUNCATE" + underscoreCS "UNDERSCORE_CHARSET" unknown "UNKNOWN" union "UNION" unique "UNIQUE" @@ -394,7 +403,7 @@ import ( FieldAsName "Field alias name" FieldAsNameOpt "Field alias name opt" FieldList "field expression list" - FromClause "From clause" + TableRefsClause "Table references clause" Function "function expr" FunctionCallAgg "Function call on aggregate data" FunctionCallConflict "Function call with reserved keyword as function name" @@ -403,8 +412,8 @@ import ( FunctionNameConflict "Built-in function call names which are conflict with keywords" FuncDatetimePrec "Function datetime precision" GlobalScope "The scope of variable" + GrantStmt "Grant statement" GroupByClause "GROUP BY clause" - GroupByList "GROUP BY list" HashString "Hashed string" HavingClause "HAVING clause" IfExists "If Exists" @@ -415,7 +424,7 @@ import ( IndexName "index name" IndexType "index type" InsertIntoStmt "INSERT INTO statement" - InsertRest "Rest part of INSERT INTO statement" + InsertValues "Rest part of INSERT/REPLACE INTO statement" IntoOpt "INTO or EmptyString" JoinTable "join table" JoinType "join type" @@ -431,15 +440,16 @@ import ( NotOpt "optional NOT" NowSym "CURRENT_TIMESTAMP/LOCALTIME/LOCALTIMESTAMP/NOW" NumLiteral "Num/Int/Float/Decimal Literal" + ObjectType "Grant statement object type" OnDuplicateKeyUpdate "ON DUPLICATE KEY UPDATE value list" Operand "operand" OptFull "Full or empty" OptInteger "Optional Integer keyword" Order "ORDER BY clause optional collation specification" OrderBy "ORDER BY clause" - OrderByItem "ORDER BY list item" + ByItem "BY item" OrderByOptional "Optional ORDER BY clause optional" - OrderByList "ORDER BY list" + ByList "BY list" OuterOpt "optional OUTER clause" QuickOptional "QUICK or empty" PasswordOpt "Password option" @@ -449,10 +459,15 @@ import ( PrimaryExpression "primary expression" PrimaryFactor "primary expression factor" Priority "insert statement priority" + PrivElem "Privilege element" + PrivElemList "Privilege element list" + PrivLevel "Privilege scope" + PrivType "Privilege type" ReferDef "Reference definition" RegexpSym "REGEXP or RLIKE" + ReplaceIntoStmt "REPLACE INTO statement" + ReplacePriority "replace statement priority" RollbackStmt "ROLLBACK statement" - SelectBasic "Basic SELECT statement without parentheses and union" SelectLockOpt "FOR UPDATE or LOCK IN SHARE MODE," SelectStmt "SELECT statement" SelectStmtCalcFoundRows "SELECT statement optional SQL_CALC_FOUND_ROWS" @@ -491,10 +506,9 @@ import ( TrimDirection "Trim string direction" TruncateTableStmt "TRANSACTION TABLE statement" UnionOpt "Union Option(empty/ALL/DISTINCT)" - UnionClause "Union select" + UnionStmt "Union select state ment" UnionClauseList "Union select clause list" - UnionClauseP "Union (select)" - UnionClausePList "Union (select) clause list" + UnionSelect "Union (select) item" UpdateStmt "UPDATE statement" Username "Username" UserSpec "Username and auth option" @@ -636,7 +650,7 @@ AlterTableSpec: { $$ = &ast.AlterTableSpec{ Tp: ast.AlterTableDropColumn, - ColumnName: $3.(*ast.ColumnName), + DropColumn: $3.(*ast.ColumnName), } } | "DROP" "PRIMARY" "KEY" @@ -776,13 +790,9 @@ ColumnNameListOpt: { $$ = []*ast.ColumnName{} } -| '(' ')' +| ColumnNameList { - $$ = []*ast.ColumnName{} - } -| '(' ColumnNameList ')' - { - $$ = $2.([]*ast.ColumnName) + $$ = $1.([]*ast.ColumnName) } CommitStmt: @@ -832,7 +842,7 @@ ColumnOption: { // See: https://dev.mysql.com/doc/refman/5.7/en/create-table.html // The CHECK clause is parsed but ignored by all storage engines. - $$ = nil + $$ = &ast.ColumnOption{} } ColumnOptionList: @@ -947,15 +957,15 @@ NowSym: SignedLiteral: Literal { - $$ = ast.ValueExpr{Val: $1} + $$ = ast.NewValueExpr($1) } | '+' NumLiteral { - $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: &ast.ValueExpr{Val: $2}} + $$ = &ast.UnaryOperationExpr{Op: opcode.Plus, V: ast.NewValueExpr($2)} } | '-' NumLiteral { - $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: &ast.ValueExpr{Val: $2}} + $$ = &ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr($2)} } // TODO: support decimal literal @@ -1141,17 +1151,19 @@ DeleteFromStmt: "DELETE" LowPriorityOptional QuickOptional IgnoreOptional "FROM" TableName WhereClauseOptional OrderByOptional LimitClause { // Single Table + join := &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil} x := &ast.DeleteStmt{ - TableRefs: &ast.Join{Left: &ast.TableSource{Source: $6.(ast.ResultSetNode)}, Right: nil}, + TableRefs: &ast.TableRefsClause{TableRefs: join}, LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - Order: $8.([]*ast.OrderByItem), } if $7 != nil { x.Where = $7.(ast.ExprNode) } - + if $8 != nil { + x.Order = $8.(*ast.OrderByClause) + } if $9 != nil { x.Limit = $9.(*ast.Limit) } @@ -1171,7 +1183,7 @@ DeleteFromStmt: MultiTable: true, BeforeFrom: true, Tables: $5.([]*ast.TableName), - TableRefs: $7.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $7.(*ast.Join)}, } if $8 != nil { x.Where = $8.(ast.ExprNode) @@ -1190,7 +1202,7 @@ DeleteFromStmt: Ignore: $4.(bool), MultiTable: true, Tables: $6.([]*ast.TableName), - TableRefs: $8.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $8.(*ast.Join)}, } if $9 != nil { x.Where = $9.(ast.ExprNode) @@ -1220,14 +1232,14 @@ DropIndexStmt: } DropTableStmt: - "DROP" "TABLE" TableNameList + "DROP" TableOrTables TableNameList { $$ = &ast.DropTableStmt{Tables: $3.([]*ast.TableName)} if yylex.(*lexer).root { break } } -| "DROP" "TABLE" "IF" "EXISTS" TableNameList +| "DROP" TableOrTables "IF" "EXISTS" TableNameList { $$ = &ast.DropTableStmt{IfExists: true, Tables: $5.([]*ast.TableName)} if yylex.(*lexer).root { @@ -1235,6 +1247,10 @@ DropTableStmt: } } +TableOrTables: + "TABLE" +| "TABLES" + EqOpt: { } @@ -1258,7 +1274,7 @@ ExplainStmt: { $$ = &ast.ExplainStmt{ Stmt: &ast.ShowStmt{ - Tp: ast.ShowTables, + Tp: ast.ShowColumns, Table: $2.(*ast.TableName), }, } @@ -1493,23 +1509,31 @@ Field: } | Identifier '.' '*' { - tn := &ast.TableName{Name:model.NewCIStr($1.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} + wildCard := &ast.WildCardField{Table: model.NewCIStr($1.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Identifier '.' Identifier '.' '*' { - tn := &ast.TableName{Schema:model.NewCIStr($1.(string)), Name:model.NewCIStr($3.(string))} - $$ = &ast.SelectField{WildCard: &ast.WildCardField{Table: tn}} + wildCard := &ast.WildCardField{Schema: model.NewCIStr($1.(string)), Table: model.NewCIStr($3.(string))} + $$ = &ast.SelectField{WildCard: wildCard} } | Expression FieldAsNameOpt { - $$ = &ast.SelectField{Expr: $1.(ast.ExprNode), AsName: $2.(model.CIStr)} + expr := $1.(ast.ExprNode) + asName := $2.(string) + if asName != "" { + // Set expr original text. + offset := yyS[yypt-1].offset + end := yyS[yypt].offset-1 + expr.SetText(yylex.(*lexer).src[offset:end]) + } + $$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)} } FieldAsNameOpt: /* EMPTY */ { - $$ = model.CIStr{} + $$ = "" } | FieldAsName { @@ -1537,27 +1561,28 @@ FieldAsName: FieldList: Field { - $$ = []*ast.SelectField{$1.(*ast.SelectField)} + field := $1.(*ast.SelectField) + field.Offset = yyS[yypt].offset + $$ = []*ast.SelectField{field} } | FieldList ',' Field { - $$ = append($1.([]*ast.SelectField), $3.(*ast.SelectField)) + + fl := $1.([]*ast.SelectField) + last := fl[len(fl)-1] + if last.Expr != nil && last.AsName.O == "" { + lastEnd := yyS[yypt-1].offset-1 // Comma offset. + last.SetText(yylex.(*lexer).src[last.Offset:lastEnd]) + } + newField := $3.(*ast.SelectField) + newField.Offset = yyS[yypt].offset + $$ = append(fl, newField) } GroupByClause: - "GROUP" "BY" GroupByList + "GROUP" "BY" ByList { - $$ = $3.([]ast.ExprNode) - } - -GroupByList: - Expression - { - $$ = []ast.ExprNode{$1.(ast.ExprNode)} - } -| GroupByList ',' Expression - { - $$ = append($1.([]ast.ExprNode), $3.(ast.ExprNode)) + $$ = &ast.GroupByClause{Items: $3.([]*ast.ByItem)} } HavingClause: @@ -1566,7 +1591,7 @@ HavingClause: } | "HAVING" Expression { - $$ = $2.(ast.ExprNode) + $$ = &ast.HavingClause{Expr: $2.(ast.ExprNode)} } IfExists: @@ -1625,10 +1650,10 @@ UnReservedKeyword: | "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" | "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION" | "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS" -| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" +| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" + "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" | "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" | "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" @@ -1639,11 +1664,13 @@ NotKeywordToken: * TODO: support PARTITION **********************************************************************************/ InsertIntoStmt: - "INSERT" Priority IgnoreOptional IntoOpt TableName InsertRest OnDuplicateKeyUpdate + "INSERT" Priority IgnoreOptional IntoOpt TableName InsertValues OnDuplicateKeyUpdate { x := $6.(*ast.InsertStmt) x.Priority = $2.(int) - x.Table = $5.(*ast.TableName) + // Wraps many layers here so that it can be processed the same way as select statement. + ts := &ast.TableSource{Source: $5.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} if $7 != nil { x.OnDuplicate = $7.([]*ast.Assignment) } @@ -1660,17 +1687,22 @@ IntoOpt: { } -InsertRest: +InsertValues: '(' ColumnNameListOpt ')' ValueSym ExpressionListList { $$ = &ast.InsertStmt{ Columns: $2.([]*ast.ColumnName), - Lists: $5.([][]ast.ExprNode)} + Lists: $5.([][]ast.ExprNode), + } } | '(' ColumnNameListOpt ')' SelectStmt { $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.SelectStmt)} } +| '(' ColumnNameListOpt ')' UnionStmt + { + $$ = &ast.InsertStmt{Columns: $2.([]*ast.ColumnName), Select: $4.(*ast.UnionStmt)} + } | ValueSym ExpressionListList %prec insertValues { $$ = &ast.InsertStmt{Lists: $2.([][]ast.ExprNode)} @@ -1679,6 +1711,10 @@ InsertRest: { $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} } +| UnionStmt + { + $$ = &ast.InsertStmt{Select: $1.(*ast.UnionStmt)} + } | "SET" ColumnSetValueList { $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} @@ -1741,7 +1777,39 @@ OnDuplicateKeyUpdate: $$ = $5 } -/***********************************Insert Statments END************************************/ +/***********************************Insert Statements END************************************/ + +/************************************************************************************ + * Replace Statements + * See: https://dev.mysql.com/doc/refman/5.7/en/replace.html + * + * TODO: support PARTITION + **********************************************************************************/ +ReplaceIntoStmt: + "REPLACE" ReplacePriority IntoOpt TableName InsertValues + { + x := $5.(*ast.InsertStmt) + x.Replace = true + x.Priority = $2.(int) + ts := &ast.TableSource{Source: $4.(*ast.TableName)} + x.Table = &ast.TableRefsClause{TableRefs: &ast.Join{Left: ts}} + $$ = x + } + +ReplacePriority: + { + $$ = ast.NoPriority + } +| "LOW_PRIORITY" + { + $$ = ast.LowPriority + } +| "DELAYED" + { + $$ = ast.DelayedPriority + } + +/***********************************Replace Statments END************************************/ Literal: "false" @@ -1756,13 +1824,39 @@ Literal: | floatLit | intLit | stringLit + { + tp := types.NewFieldType(mysql.TypeString) + l := yylex.(*lexer) + tp.Charset, tp.Collate = l.GetCharsetInfo() + $$ = &types.DataItem{ + Type: tp, + Data: $1.(string), + } + } +| "UNDERSCORE_CHARSET" stringLit + { + // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html + tp := types.NewFieldType(mysql.TypeString) + tp.Charset = $1.(string) + co, err := charset.GetDefaultCollation(tp.Charset) + if err != nil { + l := yylex.(*lexer) + l.errf("Get collation error for charset: %s", tp.Charset) + return 1 + } + tp.Collate = co + $$ = &types.DataItem{ + Type: tp, + Data: $2.(string), + } + } | hexLit | bitLit Operand: Literal { - $$ = &ast.ValueExpr{Val: $1} + $$ = ast.NewValueExpr($1) } | ColumnName { @@ -1786,7 +1880,9 @@ Operand: } | "PLACEHOLDER" { - $$ = &ast.ParamMarkerExpr{} + $$ = &ast.ParamMarkerExpr{ + Offset: yyS[yypt].offset, + } } | "ROW" '(' Expression ',' ExpressionList ')' { @@ -1804,25 +1900,25 @@ Operand: } OrderBy: - "ORDER" "BY" OrderByList + "ORDER" "BY" ByList { - $$ = $3.([]*ast.OrderByItem) + $$ = &ast.OrderByClause{Items: $3.([]*ast.ByItem)} } -OrderByList: - OrderByItem +ByList: + ByItem { - $$ = []*ast.OrderByItem{$1.(*ast.OrderByItem)} + $$ = []*ast.ByItem{$1.(*ast.ByItem)} } -| OrderByList ',' OrderByItem +| ByList ',' ByItem { - $$ = append($1.([]*ast.OrderByItem), $3.(*ast.OrderByItem)) + $$ = append($1.([]*ast.ByItem), $3.(*ast.ByItem)) } -OrderByItem: +ByItem: Expression Order { - $$ = &ast.OrderByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} + $$ = &ast.ByItem{Expr: $1.(ast.ExprNode), Desc: $2.(bool)} } Order: @@ -1898,16 +1994,16 @@ FunctionNameConflict: FunctionCallConflict: FunctionNameConflict '(' ExpressionListOpt ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: false} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURRENT_USER" { // See: https://dev.mysql.com/doc/refman/5.7/en/information-functions.html#function_current-user - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_DATE" { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } DistinctOpt: @@ -1928,18 +2024,14 @@ DistinctOpt: } FunctionCallKeyword: - "AVG" '(' DistinctOpt ExpressionList ')' - { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode), Distinct: $3.(bool)} - } -| "CAST" '(' Expression "AS" CastType ')' + "CAST" '(' Expression "AS" CastType ')' { /* See: https://dev.mysql.com/doc/refman/5.7/en/cast-functions.html#function_cast */ $$ = &ast.FuncCastExpr{ Expr: $3.(ast.ExprNode), Tp: $5.(*types.FieldType), FunctionType: ast.CastFunction, - } + } } | "CASE" ExpressionOpt WhenClauseList ElseOpt "END" { @@ -1971,34 +2063,34 @@ FunctionCallKeyword: } | "DATE" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "USER" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "VALUES" '(' ColumnName ')' %prec lowerThanInsertValues { // TODO: support qualified identifier for column_name - $$ = &ast.ColumnNameExpr{Name: $3.(*ast.ColumnName)} + $$ = &ast.ValuesExpr{Column: $3.(*ast.ColumnName)} } | "WEEK" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "YEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } FunctionCallNonKeyword: "COALESCE" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CURDATE" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "CURRENT_TIMESTAMP" FuncDatetimePrec { @@ -2006,35 +2098,53 @@ FunctionCallNonKeyword: if $2 != nil { args = append(args, $2.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "ABS" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "CONCAT" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "CONCAT_WS" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "DAY" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFWEEK" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFMONTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "DAYOFYEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + } +| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateAdd, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), + } + } +| "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' + { + $$ = &ast.FuncDateArithExpr{ + Op:ast.DateSub, + Unit: $7.(string), + Date: $3.(ast.ExprNode), + Interval: $6.(ast.ExprNode), + } } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' { @@ -2045,19 +2155,19 @@ FunctionCallNonKeyword: } | "FOUND_ROWS" '(' ')' { - $$ = &ast.FuncCallExpr{F: $1.(string)} + $$ = &ast.FuncCallExpr{FnName: $1.(string)} } | "HOUR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "IFNULL" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "LENGTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "LOCATE" '(' Expression ',' Expression ')' { @@ -2076,19 +2186,19 @@ FunctionCallNonKeyword: } | "LOWER" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MICROSECOND" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MINUTE" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "MONTH" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "NOW" '(' ExpressionOpt ')' { @@ -2096,11 +2206,11 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "NULLIF" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } | "RAND" '(' ExpressionOpt ')' { @@ -2109,11 +2219,16 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} + } +| "REPLACE" '(' Expression ',' Expression ',' Expression ')' + { + args := []ast.ExprNode{$3.(ast.ExprNode), $5.(ast.ExprNode), $7.(ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "SECOND" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "SUBSTRING" '(' Expression ',' Expression ')' { @@ -2159,7 +2274,7 @@ FunctionCallNonKeyword: if $3 != nil { args = append(args, $3.(ast.ExprNode)) } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: args} } | "TRIM" '(' Expression ')' { @@ -2191,19 +2306,19 @@ FunctionCallNonKeyword: } | "UPPER" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKDAY" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "WEEKOFYEAR" '(' Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } | "YEARWEEK" '(' ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $3.([]ast.ExprNode)} + $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } TrimDirection: @@ -2221,30 +2336,34 @@ TrimDirection: } FunctionCallAgg: - "COUNT" '(' DistinctOpt ExpressionList ')' + "AVG" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + } +| "COUNT" '(' DistinctOpt ExpressionList ')' + { + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "COUNT" '(' DistinctOpt '*' ')' { - args := []ast.ExprNode{&ast.ValueExpr{Val: ast.TypeStar("*")} } - $$ = &ast.FuncCallExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} + args := []ast.ExprNode{ast.NewValueExpr(ast.UnquoteString("*"))} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: args, Distinct: $3.(bool)} } | "GROUP_CONCAT" '(' DistinctOpt ExpressionList ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: $4.([]ast.ExprNode), Distinct: $3.(bool)} } | "MAX" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "MIN" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } | "SUM" '(' DistinctOpt Expression ')' { - $$ = &ast.FuncCallExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} + $$ = &ast.AggregateFuncExpr{F: $1.(string), Args: []ast.ExprNode{$4.(ast.ExprNode)}, Distinct: $3.(bool)} } FuncDatetimePrec: @@ -2556,57 +2675,85 @@ RollbackStmt: } SelectStmt: - SelectBasic -| SelectBasic UnionClauseList + "SELECT" SelectStmtOpts SelectStmtFieldList SelectStmtLimit SelectLockOpt { - st := $1.(*ast.SelectStmt) - st.Unions = $2.([]*ast.UnionClause) - $$ = st - } -| SubSelect UnionClausePList OrderByOptional SelectStmtLimit - { - st := $1.(*ast.SubqueryExpr).Query - st.Unions = $2.([]*ast.UnionClause) - st.UnionOrderBy = $3.([]*ast.OrderByItem) - st.UnionLimit = $4.(*ast.Limit) - $$ = st - } - -SelectBasic: - "SELECT" SelectStmtOpts SelectStmtFieldList FromDual SelectStmtLimit SelectLockOpt - { - $$ = &ast.SelectStmt { + st := &ast.SelectStmt { Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), - From: nil, - LockTp: $6.(ast.SelectLockType), + Fields: $3.(*ast.FieldList), + LockTp: $5.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + src := yylex.(*lexer).src + var lastEnd int + if $4 != nil { + lastEnd = yyS[yypt-1].offset-1 + } else if $5 != ast.SelectLockNone { + lastEnd = yyS[yypt].offset-1 + } else { + lastEnd = len(src) + if src[lastEnd-1] == ';' { + lastEnd-- + } + } + lastField.SetText(src[lastField.Offset:lastEnd]) + } + if $4 != nil { + st.Limit = $4.(*ast.Limit) + } + $$ = st + } +| "SELECT" SelectStmtOpts SelectStmtFieldList FromDual WhereClauseOptional SelectStmtLimit SelectLockOpt + { + st := &ast.SelectStmt { + Distinct: $2.(bool), + Fields: $3.(*ast.FieldList), + LockTp: $7.(ast.SelectLockType), + } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-3].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } + if $5 != nil { + st.Where = $5.(ast.ExprNode) + } + if $6 != nil { + st.Limit = $6.(*ast.Limit) + } + $$ = st } | "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - FromClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional + TableRefsClause WhereClauseOptional SelectStmtGroup HavingClause OrderByOptional SelectStmtLimit SelectLockOpt { st := &ast.SelectStmt{ Distinct: $2.(bool), - Fields: $3.([]*ast.SelectField), - From: $5.(*ast.Join), + Fields: $3.(*ast.FieldList), + From: $5.(*ast.TableRefsClause), LockTp: $11.(ast.SelectLockType), } + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Expr != nil && lastField.AsName.O == "" { + lastEnd := yyS[yypt-7].offset-1 + lastField.SetText(yylex.(*lexer).src[lastField.Offset:lastEnd]) + } + if $6 != nil { st.Where = $6.(ast.ExprNode) } if $7 != nil { - st.GroupBy = $7.([]ast.ExprNode) + st.GroupBy = $7.(*ast.GroupByClause) } if $8 != nil { - st.Having = $8.(ast.ExprNode) + st.Having = $8.(*ast.HavingClause) } if $9 != nil { - st.OrderBy = $9.([]*ast.OrderByItem) + st.OrderBy = $9.(*ast.OrderByClause) } if $10 != nil { @@ -2617,21 +2764,20 @@ SelectBasic: } FromDual: - /* Empty */ -| "FROM" "DUAL" + "FROM" "DUAL" -FromClause: +TableRefsClause: TableRefs { - $$ = $1 + $$ = &ast.TableRefsClause{TableRefs: $1.(*ast.Join)} } TableRefs: EscapedTableRef { if j, ok := $1.(*ast.Join); ok { - // if $1 is JoinRset, use it directly + // if $1 is Join, use it directly $$ = j } else { $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: nil} @@ -2674,8 +2820,14 @@ TableFactor: } | '(' SelectStmt ')' TableAsName { + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1) $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } +| '(' UnionStmt ')' TableAsName + { + $$ = &ast.TableSource{Source: $2.(*ast.UnionStmt), AsName: $4.(model.CIStr)} + } | '(' TableRefs ')' { $$ = $2 @@ -2708,11 +2860,13 @@ JoinTable: } | TableRef CrossOpt TableRef "ON" Expression { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: $5.(ast.ExprNode)} + on := &ast.OnCondition{Expr: $5.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $3.(ast.ResultSetNode), Tp: ast.CrossJoin, On: on} } | TableRef JoinType OuterOpt "JOIN" TableRef "ON" Expression - { - $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: $7.(ast.ExprNode)} + { + on := &ast.OnCondition{Expr: $7.(ast.ExprNode)} + $$ = &ast.Join{Left: $1.(ast.ResultSetNode), Right: $5.(ast.ResultSetNode), Tp: $2.(ast.JoinType), On: on} } /* Support Using */ @@ -2762,7 +2916,7 @@ SelectStmtLimit: } | "LIMIT" LengthNum "OFFSET" LengthNum { - $$ = &ast.Limit{Offset: $2.(uint64), Count: $4.(uint64)} + $$ = &ast.Limit{Offset: $4.(uint64), Count: $2.(uint64)} } SelectStmtDistinct: @@ -2799,7 +2953,7 @@ SelectStmtCalcFoundRows: SelectStmtFieldList: FieldList { - $$ = $1 + $$ = &ast.FieldList{Fields: $1.([]*ast.SelectField)} } SelectStmtGroup: @@ -2814,6 +2968,15 @@ SubSelect: '(' SelectStmt ')' { s := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1) + src := yylex.(*lexer).src + // See the implemention of yyParse function + s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) + $$ = &ast.SubqueryExpr{Query: s} + } +| '(' UnionStmt ')' + { + s := $2.(*ast.UnionStmt) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) @@ -2836,36 +2999,59 @@ SelectLockOpt: } // See: https://dev.mysql.com/doc/refman/5.7/en/union.html -UnionClause: - "UNION" UnionOpt SelectBasic +UnionStmt: + UnionClauseList "UNION" UnionOpt SelectStmt { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SelectStmt)} + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union } - -UnionClauseP: - "UNION" UnionOpt SubSelect +| UnionClauseList "UNION" UnionOpt '(' SelectStmt ')' OrderByOptional SelectStmtLimit { - $$ = &ast.UnionClause{Distinct: $2.(bool), Select: $3.(*ast.SubqueryExpr).Query} + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1) + st := $5.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, st) + if $7 != nil { + union.OrderBy = $7.(*ast.OrderByClause) + } + if $8 != nil { + union.Limit = $8.(*ast.Limit) + } + $$ = union } UnionClauseList: - UnionClause + UnionSelect { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} + selects := []*ast.SelectStmt{$1.(*ast.SelectStmt)} + $$ = &ast.UnionStmt{ + Selects: selects, + } } -| UnionClauseList UnionClause +| UnionClauseList "UNION" UnionOpt UnionSelect { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) + union := $1.(*ast.UnionStmt) + union.Distinct = union.Distinct || $3.(bool) + lastSelect := union.Selects[len(union.Selects)-1] + yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) + $$ = union } -UnionClausePList: - UnionClauseP +UnionSelect: + SelectStmt +| '(' SelectStmt ')' { - $$ = []*ast.UnionClause{$1.(*ast.UnionClause)} - } -| UnionClausePList UnionClauseP - { - $$ = append($1.([]*ast.UnionClause), $2.(*ast.UnionClause)) + st := $2.(*ast.SelectStmt) + yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1) + $$ = st } UnionOpt: @@ -3036,12 +3222,19 @@ ShowStmt: } | "SHOW" OptFull "TABLES" ShowDatabaseNameOpt ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowTables, DBName: $4.(string), Full: $2.(bool), - Where: $5.(ast.ExprNode), } + if $5 != nil { + if x, ok := $5.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { + stmt.Where = $5.(ast.ExprNode) + } + } + $$ = stmt } | "SHOW" OptFull "COLUMNS" ShowTableAliasOpt ShowDatabaseNameOpt { @@ -3058,26 +3251,49 @@ ShowStmt: } | "SHOW" GlobalScope "VARIABLES" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowVariables, GlobalScope: $2.(bool), - Where: $4.(ast.ExprNode), } + if x, ok := $4.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else if $4 != nil { + stmt.Where = $4.(ast.ExprNode) + } + $$ = stmt } | "SHOW" "COLLATION" ShowLikeOrWhereOpt { - $$ = &ast.ShowStmt{ + stmt := &ast.ShowStmt{ Tp: ast.ShowCollation, - Where: $3.(ast.ExprNode), } + if x, ok := $3.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else if $3 != nil { + stmt.Where = $3.(ast.ExprNode) + } + $$ = stmt } -| "SHOW" "CREATE" "TABLE" TableName +| "SHOW" "CREATE" "TABLE" TableName { $$ = &ast.ShowStmt{ - Tp: ast.ShowCreateTable, + Tp: ast.ShowCreateTable, Table: $4.(*ast.TableName), } } +| "SHOW" "GRANTS" + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &ast.ShowStmt{Tp: ast.ShowGrants} + } +| "SHOW" "GRANTS" "FOR" Username + { + // See: https://dev.mysql.com/doc/refman/5.7/en/show-grants.html + $$ = &ast.ShowStmt{ + Tp: ast.ShowGrants, + User: $4.(string), + } + } ShowLikeOrWhereOpt: { @@ -3154,10 +3370,13 @@ Statement: | DropDatabaseStmt | DropIndexStmt | DropTableStmt +| GrantStmt | InsertIntoStmt | PreparedStmt | RollbackStmt +| ReplaceIntoStmt | SelectStmt +| UnionStmt | SetStmt | ShowStmt | TruncateTableStmt @@ -3175,6 +3394,7 @@ ExplainableStmt: | DeleteFromStmt | UpdateStmt | InsertIntoStmt +| ReplaceIntoStmt StatementList: Statement @@ -3793,18 +4013,22 @@ StringName: UpdateStmt: "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { - // Single-table syntax - r := &ast.Join{Left: $4.(ast.ResultSetNode), Right: nil} + var refs *ast.Join + if x, ok := $4.(*ast.Join); ok { + refs = x + } else { + refs = &ast.Join{Left: $4.(ast.ResultSetNode)} + } st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: r, + TableRefs: &ast.TableRefsClause{TableRefs: refs}, List: $6.([]*ast.Assignment), } if $7 != nil { st.Where = $7.(ast.ExprNode) } if $8 != nil { - st.Order = $8.([]*ast.OrderByItem) + st.Order = $8.(*ast.OrderByClause) } if $9 != nil { st.Limit = $9.(*ast.Limit) @@ -3816,12 +4040,10 @@ UpdateStmt: } | "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional { - // Multiple-table syntax st := &ast.UpdateStmt{ LowPriority: $2.(bool), - TableRefs: $4.(*ast.Join), + TableRefs: &ast.TableRefsClause{TableRefs: $4.(*ast.Join)}, List: $6.([]*ast.Assignment), - MultipleTable: true, } if $7 != nil { st.Where = $7.(ast.ExprNode) @@ -3880,10 +4102,13 @@ CreateUserStmt: UserSpec: Username AuthOption { - $$ = &ast.UserSpec{ + userSpec := &ast.UserSpec{ User: $1.(string), - AuthOpt: $2.(*ast.AuthOption), } + if $2 != nil { + userSpec.AuthOpt = $2.(*ast.AuthOption) + } + $$ = userSpec } UserSpecList: @@ -3897,7 +4122,9 @@ UserSpecList: } AuthOption: - {} + { + $$ = nil + } | "IDENTIFIED" "BY" AuthString { $$ = &ast.AuthOption { @@ -3907,11 +4134,150 @@ AuthOption: } | "IDENTIFIED" "BY" "PASSWORD" HashString { - $$ = &ast.AuthOption { + $$ = &ast.AuthOption{ HashString: $4.(string), } } HashString: stringLit + +/************************************************************************************* + * Grant statement + * See: https://dev.mysql.com/doc/refman/5.7/en/grant.html + *************************************************************************************/ +GrantStmt: + "GRANT" PrivElemList "ON" ObjectType PrivLevel "TO" UserSpecList + { + $$ = &ast.GrantStmt{ + Privs: $2.([]*ast.PrivElem), + ObjectType: $4.(ast.ObjectTypeType), + Level: $5.(*ast.GrantLevel), + Users: $7.([]*ast.UserSpec), + } + } + +PrivElem: + PrivType + { + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + } + } +| PrivType '(' ColumnNameList ')' + { + $$ = &ast.PrivElem{ + Priv: $1.(mysql.PrivilegeType), + Cols: $3.([]*ast.ColumnName), + } + } + +PrivElemList: + PrivElem + { + $$ = []*ast.PrivElem{$1.(*ast.PrivElem)} + } +| PrivElemList ',' PrivElem + { + $$ = append($1.([]*ast.PrivElem), $3.(*ast.PrivElem)) + } + +PrivType: + "ALL" + { + $$ = mysql.AllPriv + } +| "ALTER" + { + $$ = mysql.AlterPriv + } +| "CREATE" + { + $$ = mysql.CreatePriv + } +| "CREATE" "USER" + { + $$ = mysql.CreateUserPriv + } +| "DELETE" + { + $$ = mysql.DeletePriv + } +| "DROP" + { + $$ = mysql.DropPriv + } +| "EXECUTE" + { + $$ = mysql.ExecutePriv + } +| "INDEX" + { + $$ = mysql.IndexPriv + } +| "INSERT" + { + $$ = mysql.InsertPriv + } +| "SELECT" + { + $$ = mysql.SelectPriv + } +| "SHOW" "DATABASES" + { + $$ = mysql.ShowDBPriv + } +| "UPDATE" + { + $$ = mysql.UpdatePriv + } +| "GRANT" "OPTION" + { + $$ = mysql.GrantPriv + } + +ObjectType: + { + $$ = ast.ObjectTypeNone + } +| "TABLE" + { + $$ = ast.ObjectTypeTable + } + +PrivLevel: + '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + } + } +| '*' '.' '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelGlobal, + } + } +| Identifier '.' '*' + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelDB, + DBName: $1.(string), + } + } +| Identifier '.' Identifier + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, + DBName: $1.(string), + TableName: $3.(string), + } + } +| Identifier + { + $$ = &ast.GrantLevel { + Level: ast.GrantLevelTable, + TableName: $1.(string), + } + } %% diff --git a/ast/parser/parser_test.go b/ast/parser/parser_test.go index 6b32ea65fc..3e8d7bb278 100644 --- a/ast/parser/parser_test.go +++ b/ast/parser/parser_test.go @@ -56,8 +56,8 @@ func (s *testParserSuite) TestSimple(c *C) { st := l.Stmts()[0] ss, ok := st.(*ast.SelectStmt) c.Assert(ok, IsTrue) - c.Assert(len(ss.Fields), Equals, 1) - cv, ok := ss.Fields[0].Expr.(*ast.FuncCastExpr) + c.Assert(len(ss.Fields.Fields), Equals, 1) + cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr) c.Assert(ok, IsTrue) c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction) @@ -77,3 +77,571 @@ func (s *testParserSuite) TestSimple(c *C) { c.Assert(ok, IsTrue) } } + +type testCase struct { + src string + ok bool +} + +func (s *testParserSuite) RunTest(c *C, table []testCase) { + for _, t := range table { + l := NewLexer(t.src) + ok := yyParse(l) == 0 + c.Assert(ok, Equals, t.ok, Commentf("source %v %v", t.src, l.errs)) + switch ok { + case true: + c.Assert(l.errs, HasLen, 0, Commentf("src: %s", t.src)) + case false: + c.Assert(len(l.errs), Not(Equals), 0, Commentf("src: %s", t.src)) + } + } +} +func (s *testParserSuite) TestDMLStmt(c *C) { + table := []testCase{ + {"", true}, + {";", true}, + {"INSERT INTO foo VALUES (1234)", true}, + {"INSERT INTO foo VALUES (1234, 5678)", true}, + // 15 + {"INSERT INTO foo VALUES (1 || 2)", true}, + {"INSERT INTO foo VALUES (1 | 2)", true}, + {"INSERT INTO foo VALUES (false || true)", true}, + {"INSERT INTO foo VALUES (bar(5678))", false}, + // 20 + {"INSERT INTO foo VALUES ()", true}, + {"SELECT * FROM t", true}, + {"SELECT * FROM t AS u", true}, + // 25 + {"SELECT * FROM t, v", true}, + {"SELECT * FROM t AS u, v", true}, + {"SELECT * FROM t, v AS w", true}, + {"SELECT * FROM t AS u, v AS w", true}, + {"SELECT * FROM foo, bar, foo", true}, + // 30 + {"SELECT DISTINCTS * FROM t", false}, + {"SELECT DISTINCT * FROM t", true}, + {"INSERT INTO foo (a) VALUES (42)", true}, + {"INSERT INTO foo (a,) VALUES (42,)", false}, + // 35 + {"INSERT INTO foo (a,b) VALUES (42,314)", true}, + {"INSERT INTO foo (a,b,) VALUES (42,314)", false}, + {"INSERT INTO foo (a,b,) VALUES (42,314,)", false}, + {"INSERT INTO foo () VALUES ()", true}, + {"INSERT INTO foo VALUE ()", true}, + + {"REPLACE INTO foo VALUES (1 || 2)", true}, + {"REPLACE INTO foo VALUES (1 | 2)", true}, + {"REPLACE INTO foo VALUES (false || true)", true}, + {"REPLACE INTO foo VALUES (bar(5678))", false}, + {"REPLACE INTO foo VALUES ()", true}, + {"REPLACE INTO foo (a,b) VALUES (42,314)", true}, + {"REPLACE INTO foo (a,b,) VALUES (42,314)", false}, + {"REPLACE INTO foo (a,b,) VALUES (42,314,)", false}, + {"REPLACE INTO foo () VALUES ()", true}, + {"REPLACE INTO foo VALUE ()", true}, + // 40 + {`SELECT stuff.id + FROM stuff + WHERE stuff.value >= ALL (SELECT stuff.value + FROM stuff)`, true}, + {"BEGIN", true}, + {"START TRANSACTION", true}, + // 45 + {"COMMIT", true}, + {"ROLLBACK", true}, + {` + BEGIN; + INSERT INTO foo VALUES (42, 3.14); + INSERT INTO foo VALUES (-1, 2.78); + COMMIT;`, true}, + {` // A + BEGIN; + INSERT INTO tmp SELECT * from bar; + SELECT * from tmp; + + // B + ROLLBACK;`, true}, + + // set + // user defined + {"SET @a = 1", true}, + // session system variables + {"SET SESSION autocommit = 1", true}, + {"SET @@session.autocommit = 1", true}, + {"SET LOCAL autocommit = 1", true}, + {"SET @@local.autocommit = 1", true}, + {"SET @@autocommit = 1", true}, + {"SET autocommit = 1", true}, + // global system variables + {"SET GLOBAL autocommit = 1", true}, + {"SET @@global.autocommit = 1", true}, + // SET CHARACTER SET + {"SET CHARACTER SET utf8mb4;", true}, + {"SET CHARACTER SET 'utf8mb4';", true}, + // Set password + {"SET PASSWORD = 'password';", true}, + {"SET PASSWORD FOR 'root'@'localhost' = 'password';", true}, + + // qualified select + {"SELECT a.b.c FROM t", true}, + {"SELECT a.b.*.c FROM t", false}, + {"SELECT a.b.* FROM t", true}, + {"SELECT a FROM t", true}, + {"SELECT a.b.c.d FROM t", false}, + + // Do statement + {"DO 1", true}, + {"DO 1 from t", false}, + + // Select for update + {"SELECT * from t for update", true}, + {"SELECT * from t lock in share mode", true}, + + // For alter table + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED", true}, + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED FIRST", true}, + {"ALTER TABLE t ADD COLUMN a SMALLINT UNSIGNED AFTER b", true}, + + // from join + {"SELECT * from t1, t2, t3", true}, + {"select * from t1 join t2 left join t3 on t2.id = t3.id", true}, + {"select * from t1 right join t2 on t1.id = t2.id left join t3 on t3.id = t2.id", true}, + {"select * from t1 right join t2 on t1.id = t2.id left join t3", false}, + + // For show full columns + {"show columns in t;", true}, + {"show full columns in t;", true}, + + // For set names + {"set names utf8", true}, + {"set names utf8 collate utf8_unicode_ci", true}, + + // For show character set + {"show character set;", true}, + // For on duplicate key update + {"INSERT INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, + {"INSERT IGNORE INTO t (a,b,c) VALUES (1,2,3),(4,5,6) ON DUPLICATE KEY UPDATE c=VALUES(a)+VALUES(b);", true}, + + // For SHOW statement + {"SHOW VARIABLES LIKE 'character_set_results'", true}, + {"SHOW GLOBAL VARIABLES LIKE 'character_set_results'", true}, + {"SHOW SESSION VARIABLES LIKE 'character_set_results'", true}, + {"SHOW VARIABLES", true}, + {"SHOW GLOBAL VARIABLES", true}, + {"SHOW GLOBAL VARIABLES WHERE Variable_name = 'autocommit'", true}, + {`SHOW FULL TABLES FROM icar_qa LIKE play_evolutions`, true}, + {`SHOW FULL TABLES WHERE Table_Type != 'VIEW'`, true}, + {`SHOW GRANTS`, true}, + {`SHOW GRANTS FOR 'test'@'localhost'`, true}, + + // For default value + {"CREATE TABLE sbtest (id INTEGER UNSIGNED NOT NULL AUTO_INCREMENT, k integer UNSIGNED DEFAULT '0' NOT NULL, c char(120) DEFAULT '' NOT NULL, pad char(60) DEFAULT '' NOT NULL, PRIMARY KEY (id) )", true}, + + // For delete statement + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, + + // For update statement + {"UPDATE t SET id = id + 1 ORDER BY id DESC;", true}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id;", true}, + {"UPDATE items,month SET items.price=month.price WHERE items.id=month.id LIMIT 10;", false}, + {"UPDATE user T0 LEFT OUTER JOIN user_profile T1 ON T1.id = T0.profile_id SET T0.profile_id = 1 WHERE T0.profile_id IN (1);", true}, + + // For select with where clause + {"SELECT * FROM t WHERE 1 = 1", true}, + + // For show collation + {"show collation", true}, + {"show collation like 'utf8%'", true}, + {"show collation where Charset = 'utf8' and Collation = 'utf8_bin'", true}, + + // For dual + {"select 1 from dual", true}, + {"select 1 from dual limit 1", true}, + {"select 1 where exists (select 2)", false}, + {"select 1 from dual where not exists (select 2)", true}, + + // For show create table + {"show create table test.t", true}, + {"show create table t", true}, + + // For https://github.com/pingcap/tidb/issues/320 + {`(select 1);`, true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestExpression(c *C) { + table := []testCase{ + // Sign expression + {"SELECT ++1", true}, + {"SELECT -*1", false}, + {"SELECT -+1", true}, + {"SELECT -1", true}, + {"SELECT --1", true}, + + // For string literal + {`select '''a''', """a"""`, true}, + {`select ''a''`, false}, + {`select ""a""`, false}, + {`select '''a''';`, true}, + {`select '\'a\'';`, true}, + {`select "\"a\"";`, true}, + {`select """a""";`, true}, + {`select _utf8"string";`, true}, + // For comparison + {"select 1 <=> 0, 1 <=> null, 1 = null", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestBuiltin(c *C) { + table := []testCase{ + // For buildin functions + {"SELECT DAYOFMONTH('2007-02-03');", true}, + {"SELECT RAND();", true}, + {"SELECT RAND(1);", true}, + + {"SELECT SUBSTRING('Quadratically',5);", true}, + {"SELECT SUBSTRING('Quadratically',5, 3);", true}, + {"SELECT SUBSTRING('Quadratically' FROM 5);", true}, + {"SELECT SUBSTRING('Quadratically' FROM 5 FOR 3);", true}, + + {"SELECT CONVERT('111', SIGNED);", true}, + + {"SELECT DATABASE();", true}, + {"SELECT USER();", true}, + {"SELECT CURRENT_USER();", true}, + {"SELECT CURRENT_USER;", true}, + + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', 2);", true}, + {"SELECT SUBSTRING_INDEX('www.mysql.com', '.', -2);", true}, + + {`SELECT LOWER("A"), UPPER("a")`, true}, + + {`SELECT REPLACE('www.mysql.com', 'w', 'Ww')`, true}, + + {`SELECT LOCATE('bar', 'foobarbar');`, true}, + {`SELECT LOCATE('bar', 'foobarbar', 5);`, true}, + + {"select current_date, current_date(), curdate()", true}, + + // For delete statement + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE FROM t1, t2 USING t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;", true}, + {"DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id limit 10;", false}, + + // For time fsp + {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, + + // For row + {"select row(1)", false}, + {"select row(1, 1,)", false}, + {"select (1, 1,)", false}, + {"select row(1, 1) > row(1, 1), row(1, 1, 1) > row(1, 1, 1)", true}, + {"Select (1, 1) > (1, 1)", true}, + {"create table t (row int)", true}, + + // For cast with charset + {"SELECT *, CAST(data AS CHAR CHARACTER SET utf8) FROM t;", true}, + + // For binary operator + {"SELECT binary 'a';", true}, + + // Select time + {"select current_timestamp", true}, + {"select current_timestamp()", true}, + {"select current_timestamp(6)", true}, + {"select now()", true}, + {"select now(6)", true}, + {"select sysdate(), sysdate(6)", true}, + + // For time extract + {`select extract(microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(week from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(month from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(quarter from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(year from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(second_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(minute_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(hour_minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_microsecond from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_second from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_minute from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(day_hour from "2011-11-11 10:10:10.123456")`, true}, + {`select extract(year_month from "2011-11-11 10:10:10.123456")`, true}, + + // For issue 224 + {`SELECT CAST('test collated returns' AS CHAR CHARACTER SET utf8) COLLATE utf8_bin;`, true}, + + // For trim + {`SELECT TRIM(' bar ');`, true}, + {`SELECT TRIM(LEADING 'x' FROM 'xxxbarxxx');`, true}, + {`SELECT TRIM(BOTH 'x' FROM 'xxxbarxxx');`, true}, + {`SELECT TRIM(TRAILING 'xyz' FROM 'barxxyz');`, true}, + + // For date_add + {`select date_add("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + + // For date_sub + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select date_sub("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestIdentifier(c *C) { + table := []testCase{ + // For quote identifier + {"select `a`, `a.b`, `a b` from t", true}, + // For unquoted identifier + {"create table MergeContextTest$Simple (value integer not null, primary key (value))", true}, + // For as + {"select 1 as a, 1 as `a`, 1 as \"a\", 1 as 'a'", true}, + {`select 1 as a, 1 as "a", 1 as 'a'`, true}, + {`select 1 a, 1 "a", 1 'a'`, true}, + {`select * from t as "a"`, false}, + {`select * from t a`, true}, + {`select * from t as a`, true}, + {"select 1 full, 1 row, 1 abs", true}, + {"select * from t full, t1 row, t2 abs", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestDDL(c *C) { + table := []testCase{ + {"CREATE", false}, + {"CREATE TABLE", false}, + {"CREATE TABLE foo (", false}, + {"CREATE TABLE foo ()", false}, + {"CREATE TABLE foo ();", false}, + {"CREATE TABLE foo (a TINYINT UNSIGNED);", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, + {"CREATE TABLE foo (a bigint unsigned, b bool);", true}, + {"CREATE TABLE foo (a TINYINT, b SMALLINT) CREATE TABLE bar (x INT, y int64)", false}, + {"CREATE TABLE foo (a int, b float); CREATE TABLE bar (x double, y float)", true}, + {"CREATE TABLE foo (a bytes)", false}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED)", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) -- foo", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) // foo", true}, + {"CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, + {"CREATE TABLE foo /* foo */ (a SMALLINT UNSIGNED, b INT UNSIGNED) /* foo */", true}, + {"CREATE TABLE foo (name CHAR(50) BINARY)", true}, + {"CREATE TABLE foo (name CHAR(50) COLLATE utf8_bin)", true}, + {"CREATE TABLE foo (name CHAR(50) CHARACTER SET utf8)", true}, + {"CREATE TABLE foo (name CHAR(50) BINARY CHARACTER SET utf8 COLLATE utf8_bin)", true}, + + {"CREATE TABLE foo (a.b, b);", false}, + {"CREATE TABLE foo (a, b.c);", false}, + // For table option + {"create table t (c int) avg_row_length = 3", true}, + {"create table t (c int) avg_row_length 3", true}, + {"create table t (c int) checksum = 0", true}, + {"create table t (c int) checksum 1", true}, + {"create table t (c int) compression = none", true}, + {"create table t (c int) compression lz4", true}, + {"create table t (c int) connection = 'abc'", true}, + {"create table t (c int) connection 'abc'", true}, + {"create table t (c int) key_block_size = 1024", true}, + {"create table t (c int) key_block_size 1024", true}, + {"create table t (c int) max_rows = 1000", true}, + {"create table t (c int) max_rows 1000", true}, + {"create table t (c int) min_rows = 1000", true}, + {"create table t (c int) min_rows 1000", true}, + {"create table t (c int) password = 'abc'", true}, + {"create table t (c int) password 'abc'", true}, + // For check clause + {"create table t (c1 bool, c2 bool, check (c1 in (0, 1)), check (c2 in (0, 1)))", true}, + {"CREATE TABLE Customer (SD integer CHECK (SD > 0), First_Name varchar(30));", true}, + + {"create database xxx", true}, + {"create database if exists xxx", false}, + {"create database if not exists xxx", true}, + {"create schema xxx", true}, + {"create schema if exists xxx", false}, + {"create schema if not exists xxx", true}, + // For drop datbase/schema/table + {"drop database xxx", true}, + {"drop database if exists xxx", true}, + {"drop database if not exists xxx", false}, + {"drop schema xxx", true}, + {"drop schema if exists xxx", true}, + {"drop schema if not exists xxx", false}, + {"drop table xxx", true}, + {"drop table xxx, yyy", true}, + {"drop tables xxx", true}, + {"drop tables xxx, yyy", true}, + {"drop table if exists xxx", true}, + {"drop table if not exists xxx", false}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestType(c *C) { + table := []testCase{ + // For time fsp + {"CREATE TABLE t( c1 TIME(2), c2 DATETIME(2), c3 TIMESTAMP(2) );", true}, + + // For hexadecimal + {"SELECT x'0a', X'11', 0x11", true}, + {"select x'0xaa'", false}, + {"select 0X11", false}, + + // For bit + {"select 0b01, 0b0, b'11', B'11'", true}, + {"select 0B01", false}, + {"select 0b21", false}, + + // For enum and set type + {"create table t (c1 enum('a', 'b'), c2 set('a', 'b'))", true}, + {"create table t (c1 enum)", false}, + {"create table t (c1 set)", false}, + + // For blob and text field length + {"create table t (c1 blob(1024), c2 text(1024))", true}, + + // For year + {"create table t (y year(4), y1 year)", true}, + + // For national + {"create table t (c1 national char(2), c2 national varchar(2))", true}, + + // For https://github.com/pingcap/tidb/issues/312 + {`create table t (c float(53));`, true}, + {`create table t (c float(54));`, false}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestPrivilege(c *C) { + table := []testCase{ + // For create user + {`CREATE USER IF NOT EXISTS 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password', 'root'@'127.0.0.1' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + + // For grant statement + {"GRANT ALL ON db1.* TO 'jeffrey'@'localhost';", true}, + {"GRANT SELECT ON db2.invoice TO 'jeffrey'@'localhost';", true}, + {"GRANT ALL ON *.* TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON *.* TO 'someuser'@'somehost';", true}, + {"GRANT ALL ON mydb.* TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON mydb.* TO 'someuser'@'somehost';", true}, + {"GRANT ALL ON mydb.mytbl TO 'someuser'@'somehost';", true}, + {"GRANT SELECT, INSERT ON mydb.mytbl TO 'someuser'@'somehost';", true}, + {"GRANT SELECT (col1), INSERT (col1,col2) ON mydb.mytbl TO 'someuser'@'somehost';", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestComment(c *C) { + table := []testCase{ + {"create table t (c int comment 'comment')", true}, + {"create table t (c int) comment = 'comment'", true}, + {"create table t (c int) comment 'comment'", true}, + {"create table t (c int) comment comment", false}, + {"create table t (comment text)", true}, + // For comment in query + {"/*comment*/ /*comment*/ select c /* this is a comment */ from t;", true}, + } + s.RunTest(c, table) +} +func (s *testParserSuite) TestSubquery(c *C) { + table := []testCase{ + // For compare subquery + {"SELECT 1 > (select 1)", true}, + {"SELECT 1 > ANY (select 1)", true}, + {"SELECT 1 > ALL (select 1)", true}, + {"SELECT 1 > SOME (select 1)", true}, + + // For exists subquery + {"SELECT EXISTS select 1", false}, + {"SELECT EXISTS (select 1)", true}, + {"SELECT + EXISTS (select 1)", true}, + {"SELECT - EXISTS (select 1)", true}, + {"SELECT NOT EXISTS (select 1)", true}, + {"SELECT + NOT EXISTS (select 1)", false}, + {"SELECT - NOT EXISTS (select 1)", false}, + } + s.RunTest(c, table) +} +func (s *testParserSuite) TestUnion(c *C) { + table := []testCase{ + {"select c1 from t1 union select c2 from t2", true}, + {"select c1 from t1 union (select c2 from t2)", true}, + {"select c1 from t1 union (select c2 from t2) order by c1", true}, + {"select c1 from t1 union select c2 from t2 order by c2", true}, + {"select c1 from t1 union (select c2 from t2) limit 1", true}, + {"select c1 from t1 union (select c2 from t2) limit 1, 1", true}, + {"select c1 from t1 union (select c2 from t2) order by c1 limit 1", true}, + {"(select c1 from t1) union distinct select c2 from t2", true}, + {"(select c1 from t1) union all select c2 from t2", true}, + {"(select c1 from t1) union (select c2 from t2) order by c1 union select c3 from t3", false}, + {"(select c1 from t1) union (select c2 from t2) limit 1 union select c3 from t3", false}, + {"(select c1 from t1) union select c2 from t2 union (select c3 from t3) order by c1 limit 1", true}, + {"select (select 1 union select 1) as a", true}, + {"select * from (select 1 union select 2) as a", true}, + {"insert into t select c1 from t1 union select c2 from t2", true}, + {"insert into t (c) select c1 from t1 union select c2 from t2", true}, + } + s.RunTest(c, table) +} + +func (s *testParserSuite) TestLikeEscape(c *C) { + table := []testCase{ + // For like escape + {`select "abc_" like "abc\\_" escape ''`, true}, + {`select "abc_" like "abc\\_" escape '\\'`, true}, + {`select "abc_" like "abc\\_" escape '||'`, false}, + {`select "abc" like "escape" escape '+'`, true}, + } + + s.RunTest(c, table) +} diff --git a/ast/parser/scanner.l b/ast/parser/scanner.l index a23a68a72a..9ee7b1b8a9 100644 --- a/ast/parser/scanner.l +++ b/ast/parser/scanner.l @@ -27,6 +27,8 @@ import ( "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/charset" + "github.com/pingcap/tidb/util/stringutil" ) type lexer struct { @@ -46,12 +48,17 @@ type lexer struct { val []byte ungetBuf []byte root bool + prepare bool stmtStartPos int stringLit []byte // record token's offset of the input tokenEndOffset int tokenStartOffset int + + // Charset information + charset string + collation string } @@ -86,6 +93,14 @@ func (l *lexer) SetInj(inj int) { l.inj = inj } +func (l *lexer) SetPrepare() { + l.prepare = true +} + +func (l *lexer) IsPrepare() bool { + return l.prepare +} + func (l *lexer) Root() bool { return l.root } @@ -94,6 +109,25 @@ func (l *lexer) SetRoot(root bool) { l.root = root } +func (l *lexer) SetCharsetInfo(charset, collation string) { + l.charset = charset + l.collation = collation +} + +func (l *lexer) GetCharsetInfo() (string, string) { + return l.charset, l.collation +} + +// The select statement is not at the end of the whole statement, if the last +// field text was set from its offset to the end of the src string, update +// the last field text. +func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { + lastField := st.Fields.Fields[len(st.Fields.Fields)-1] + if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 { + lastField.SetText(l.src[lastField.Offset:lastEnd]) + } +} + func (l *lexer) unget(b byte) { l.ungetBuf = append(l.ungetBuf, b) l.i-- @@ -273,6 +307,8 @@ current_date {c}{u}{r}{r}{e}{n}{t}_{d}{a}{t}{e} current_user {c}{u}{r}{r}{e}{n}{t}_{u}{s}{e}{r} database {d}{a}{t}{a}{b}{a}{s}{e} databases {d}{a}{t}{a}{b}{a}{s}{e}{s} +date_add {d}{a}{t}{e}_{a}{d}{d} +date_sub {d}{a}{t}{e}_{s}{u}{b} day {d}{a}{y} dayofweek {d}{a}{y}{o}{f}{w}{e}{e}{k} dayofmonth {d}{a}{y}{o}{f}{m}{o}{n}{t}{h} @@ -306,6 +342,8 @@ from {f}{r}{o}{m} full {f}{u}{l}{l} fulltext {f}{u}{l}{l}{t}{e}{x}{t} global {g}{l}{o}{b}{a}{l} +grant {g}{r}{a}{n}{t} +grants {g}{r}{a}{n}{t}{s} group {g}{r}{o}{u}{p} group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} having {h}{a}{v}{i}{n}{g} @@ -319,6 +357,7 @@ in {i}{n} index {i}{n}{d}{e}{x} inner {i}{n}{n}{e}{r} insert {i}{n}{s}{e}{r}{t} +interval {i}{n}{t}{e}{r}{v}{a}{l} into {i}{n}{t}{o} is {i}{s} join {j}{o}{i}{n} @@ -346,6 +385,7 @@ national {n}{a}{t}{i}{o}{n}{a}{l} not {n}{o}{t} offset {o}{f}{f}{s}{e}{t} on {o}{n} +option {o}{p}{t}{i}{o}{n} or {o}{r} order {o}{r}{d}{e}{r} outer {o}{u}{t}{e}{r} @@ -358,6 +398,7 @@ rand {r}{a}{n}{d} repeat {r}{e}{p}{e}{a}{t} references {r}{e}{f}{e}{r}{e}{n}{c}{e}{s} regexp {r}{e}{g}{e}{x}{p} +replace {r}{e}{p}{l}{a}{c}{e} right {r}{i}{g}{h}{t} rlike {r}{l}{i}{k}{e} rollback {r}{o}{l}{l}{b}{a}{c}{k} @@ -379,6 +420,7 @@ sysdate {s}{y}{s}{d}{a}{t}{e} table {t}{a}{b}{l}{e} tables {t}{a}{b}{l}{e}{s} then {t}{h}{e}{n} +to {t}{o} trailing {t}{r}{a}{i}{l}{i}{n}{g} transaction {t}{r}{a}{n}{s}{a}{c}{t}{i}{o}{n} trim {t}{r}{i}{m} @@ -616,6 +658,10 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {database} lval.item = string(l.val) return database {databases} return databases +{date_add} lval.item = string(l.val) + return dateAdd +{date_sub} lval.item = string(l.val) + return dateSub {day} lval.item = string(l.val) return day {dayofweek} lval.item = string(l.val) @@ -673,6 +719,9 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {full} lval.item = string(l.val) return full {fulltext} return fulltext +{grant} return grant +{grants} lval.item = string(l.val) + return grants {group} return group {group_concat} lval.item = string(l.val) return groupConcat @@ -696,6 +745,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {index} return index {inner} return inner {insert} return insert +{interval} return interval {into} return into {in} return in {is} return is @@ -747,6 +797,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {offset} lval.item = string(l.val) return offset {on} return on +{option} return option {order} return order {or} return or {outer} return outer @@ -780,6 +831,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {repeat} lval.item = string(l.val) return repeat {regexp} return regexp +{replace} lval.item = string(l.val) + return replace {references} return references {rlike} return rlike @@ -809,6 +862,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {tables} lval.item = string(l.val) return tables {then} return then +{to} return to {trailing} return trailing {transaction} lval.item = string(l.val) return transaction @@ -975,7 +1029,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} return integerType {ident} lval.item = string(l.val) - return identifier + return l.handleIdent(lval) . return c0 @@ -1001,7 +1055,8 @@ func (l *lexer) str(lval *yySymType, pref string) int { s = strings.TrimSuffix(s, "'") + "\"" pref = "\"" } - v, err := strconv.Unquote(pref + s) + v := stringutil.RemoveUselessBackslash(pref+s) + v, err := strconv.Unquote(v) if err != nil { v = strings.TrimSuffix(s, pref) } @@ -1065,3 +1120,19 @@ func (l *lexer) bit(lval *yySymType) int { lval.item = b return bitLit } + +func (l *lexer) handleIdent(lval *yySymType) int { + s := lval.item.(string) + // A character string literal may have an optional character set introducer and COLLATE clause: + // [_charset_name]'string' [COLLATE collation_name] + // See: https://dev.mysql.com/doc/refman/5.7/en/charset-literal.html + if !strings.HasPrefix(s, "_") { + return identifier + } + cs, _, err := charset.GetCharsetInfo(s[1:]) + if err != nil { + return identifier + } + lval.item = cs + return underscoreCS +} diff --git a/ast/parser/yy_parser.go b/ast/parser/yy_parser.go new file mode 100644 index 0000000000..dfc3264796 --- /dev/null +++ b/ast/parser/yy_parser.go @@ -0,0 +1,19 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package parser + +// YYParse is an wrapper of `yyParse` to make it exported. +func YYParse(yylex yyLexer) int { + return yyParse(yylex) +} diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 864567a29f..a1805a0bcd 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -19,12 +19,13 @@ import ( "github.com/ngaut/log" . "github.com/pingcap/check" "github.com/pingcap/tidb" + "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/optimizer" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/stmt/stmts" @@ -198,5 +199,7 @@ func statement(sql string) stmt.Statement { log.Debug("Compile", sql) lexer := parser.NewLexer(sql) parser.YYParse(lexer) - return lexer.Stmts()[0].(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, _ := compiler.Compile(lexer.Stmts()[0]) + return stm } diff --git a/optimizer/aggregator.go b/optimizer/aggregator.go new file mode 100644 index 0000000000..452d047c47 --- /dev/null +++ b/optimizer/aggregator.go @@ -0,0 +1,26 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +// Aggregator is the interface to +// compute aggregate function result. +type Aggregator interface { + // Input adds an input value to aggregator. + // The input values are accumulated in the aggregator. + Input(in ...interface{}) error + // Output uses input values to compute the aggregated result. + Output() interface{} + // Clear clears the input values. + Clear() +} diff --git a/optimizer/compiler.go b/optimizer/compiler.go new file mode 100644 index 0000000000..a9e42cd8b0 --- /dev/null +++ b/optimizer/compiler.go @@ -0,0 +1,132 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "sort" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/stmt" +) + +// Compiler compiles ast.Node into an executable statement. +type Compiler struct { + converter *expressionConverter +} + +// Compile compiles a ast.Node into an executable statement. +func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) { + validator := &validator{} + if _, ok := node.Accept(validator); !ok { + return nil, errors.Trace(validator.err) + } + + // binder := &InfoBinder{} + // if _, ok := node.Accept(validator); !ok { + // return nil, errors.Trace(binder.Err) + // } + + tpComputer := &typeComputer{} + if _, ok := node.Accept(tpComputer); !ok { + return nil, errors.Trace(tpComputer.err) + } + c := newExpressionConverter() + com.converter = c + switch v := node.(type) { + case *ast.InsertStmt: + return convertInsert(c, v) + case *ast.DeleteStmt: + return convertDelete(c, v) + case *ast.UpdateStmt: + return convertUpdate(c, v) + case *ast.SelectStmt: + return convertSelect(c, v) + case *ast.UnionStmt: + return convertUnion(c, v) + case *ast.CreateDatabaseStmt: + return convertCreateDatabase(c, v) + case *ast.DropDatabaseStmt: + return convertDropDatabase(c, v) + case *ast.CreateTableStmt: + return convertCreateTable(c, v) + case *ast.DropTableStmt: + return convertDropTable(c, v) + case *ast.CreateIndexStmt: + return convertCreateIndex(c, v) + case *ast.DropIndexStmt: + return convertDropIndex(c, v) + case *ast.AlterTableStmt: + return convertAlterTable(c, v) + case *ast.TruncateTableStmt: + return convertTruncateTable(c, v) + case *ast.ExplainStmt: + return convertExplain(c, v) + case *ast.PrepareStmt: + return convertPrepare(c, v) + case *ast.DeallocateStmt: + return convertDeallocate(c, v) + case *ast.ExecuteStmt: + return convertExecute(c, v) + case *ast.ShowStmt: + return convertShow(c, v) + case *ast.BeginStmt: + return convertBegin(c, v) + case *ast.CommitStmt: + return convertCommit(c, v) + case *ast.RollbackStmt: + return convertRollback(c, v) + case *ast.UseStmt: + return convertUse(c, v) + case *ast.SetStmt: + return convertSet(c, v) + case *ast.SetCharsetStmt: + return convertSetCharset(c, v) + case *ast.SetPwdStmt: + return convertSetPwd(c, v) + case *ast.CreateUserStmt: + return convertCreateUser(c, v) + case *ast.DoStmt: + return convertDo(c, v) + case *ast.GrantStmt: + return convertGrant(c, v) + } + return nil, nil +} + +type paramMarkers []*ast.ParamMarkerExpr + +func (p paramMarkers) Len() int { + return len(p) +} + +func (p paramMarkers) Less(i, j int) bool { + return p[i].Offset < p[j].Offset +} + +func (p paramMarkers) Swap(i, j int) { + p[i], p[j] = p[j], p[i] +} + +// ParamMarkers returns parameter markers for prepared statement. +func (com *Compiler) ParamMarkers() []*expression.ParamMarker { + c := com.converter + sort.Sort(c.paramMarkers) + oldMarkers := make([]*expression.ParamMarker, len(c.paramMarkers)) + for i, val := range c.paramMarkers { + oldMarkers[i] = c.exprMap[val].(*expression.ParamMarker) + } + return oldMarkers +} diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go new file mode 100644 index 0000000000..a136819ae1 --- /dev/null +++ b/optimizer/convert_expr.go @@ -0,0 +1,425 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/subquery" + "github.com/pingcap/tidb/model" + "strings" +) + +func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression.Expression, error) { + expr.Accept(converter) + if converter.err != nil { + return nil, errors.Trace(converter.err) + } + return converter.exprMap[expr], nil +} + +// expressionConverter converts ast expression to +// old expression for transition state. +type expressionConverter struct { + exprMap map[ast.Node]expression.Expression + paramMarkers paramMarkers + err error +} + +func newExpressionConverter() *expressionConverter { + return &expressionConverter{ + exprMap: map[ast.Node]expression.Expression{}, + } +} + +// Enter implements ast.Visitor interface. +func (c *expressionConverter) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +// Leave implements ast.Visitor interface. +func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + c.value(v) + case *ast.BetweenExpr: + c.between(v) + case *ast.BinaryOperationExpr: + c.binaryOperation(v) + case *ast.WhenClause: + c.whenClause(v) + case *ast.CaseExpr: + c.caseExpr(v) + case *ast.SubqueryExpr: + c.subquery(v) + case *ast.CompareSubqueryExpr: + c.compareSubquery(v) + case *ast.ColumnNameExpr: + c.columnNameExpr(v) + case *ast.DefaultExpr: + c.defaultExpr(v) + case *ast.IdentifierExpr: + c.identifier(v) + case *ast.ExistsSubqueryExpr: + c.existsSubquery(v) + case *ast.PatternInExpr: + c.patternIn(v) + case *ast.IsNullExpr: + c.isNull(v) + case *ast.IsTruthExpr: + c.isTruth(v) + case *ast.PatternLikeExpr: + c.patternLike(v) + case *ast.ParamMarkerExpr: + c.paramMarker(v) + case *ast.ParenthesesExpr: + c.parentheses(v) + case *ast.PositionExpr: + c.position(v) + case *ast.PatternRegexpExpr: + c.patternRegexp(v) + case *ast.RowExpr: + c.row(v) + case *ast.UnaryOperationExpr: + c.unaryOperation(v) + case *ast.ValuesExpr: + c.values(v) + case *ast.VariableExpr: + c.variable(v) + case *ast.FuncCallExpr: + c.funcCall(v) + case *ast.FuncExtractExpr: + c.funcExtract(v) + case *ast.FuncConvertExpr: + c.funcConvert(v) + case *ast.FuncCastExpr: + c.funcCast(v) + case *ast.FuncSubstringExpr: + c.funcSubstring(v) + case *ast.FuncLocateExpr: + c.funcLocate(v) + case *ast.FuncTrimExpr: + c.funcTrim(v) + case *ast.FuncDateArithExpr: + c.funcDateArith(v) + case *ast.AggregateFuncExpr: + c.aggregateFunc(v) + } + return in, c.err == nil +} + +func (c *expressionConverter) value(v *ast.ValueExpr) { + c.exprMap[v] = expression.Value{Val: v.GetValue()} +} + +func (c *expressionConverter) between(v *ast.BetweenExpr) { + oldExpr := c.exprMap[v.Expr] + oldLo := c.exprMap[v.Left] + oldHi := c.exprMap[v.Right] + oldBetween, err := expression.NewBetween(oldExpr, oldLo, oldHi, v.Not) + if err != nil { + c.err = err + return + } + c.exprMap[v] = oldBetween +} + +func (c *expressionConverter) binaryOperation(v *ast.BinaryOperationExpr) { + oldeLeft := c.exprMap[v.L] + oldRight := c.exprMap[v.R] + oldBinop := expression.NewBinaryOperation(v.Op, oldeLeft, oldRight) + c.exprMap[v] = oldBinop +} + +func (c *expressionConverter) whenClause(v *ast.WhenClause) { + oldExpr := c.exprMap[v.Expr] + oldResult := c.exprMap[v.Result] + oldWhenClause := &expression.WhenClause{Expr: oldExpr, Result: oldResult} + c.exprMap[v] = oldWhenClause +} + +func (c *expressionConverter) caseExpr(v *ast.CaseExpr) { + oldValue := c.exprMap[v.Value] + oldWhenClauses := make([]*expression.WhenClause, len(v.WhenClauses)) + for i, val := range v.WhenClauses { + oldWhenClauses[i] = c.exprMap[val].(*expression.WhenClause) + } + oldElse := c.exprMap[v.ElseClause] + oldCaseExpr := &expression.FunctionCase{ + Value: oldValue, + WhenClauses: oldWhenClauses, + ElseClause: oldElse, + } + c.exprMap[v] = oldCaseExpr +} + +func (c *expressionConverter) subquery(v *ast.SubqueryExpr) { + oldSubquery := &subquery.SubQuery{} + switch x := v.Query.(type) { + case *ast.SelectStmt: + oldSelect, err := convertSelect(c, x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(c, x) + if err != nil { + c.err = err + return + } + oldSubquery.Stmt = oldUnion + } + c.exprMap[v] = oldSubquery +} + +func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) { + expr := c.exprMap[v.L] + subquery := c.exprMap[v.R] + oldCmpSubquery := expression.NewCompareSubQuery(v.Op, expr, subquery.(expression.SubQuery), v.All) + c.exprMap[v] = oldCmpSubquery +} + +func joinColumnName(columnName *ast.ColumnName) string { + var originStrs []string + if columnName.Schema.O != "" { + originStrs = append(originStrs, columnName.Schema.O) + } + if columnName.Table.O != "" { + originStrs = append(originStrs, columnName.Table.O) + } + originStrs = append(originStrs, columnName.Name.O) + return strings.Join(originStrs, ".") +} + +func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) { + ident := &expression.Ident{} + ident.CIStr = model.NewCIStr(joinColumnName(v.Name)) + c.exprMap[v] = ident +} + +func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) { + oldDefault := &expression.Default{} + if v.Name != nil { + oldDefault.Name = joinColumnName(v.Name) + } + c.exprMap[v] = oldDefault +} + +func (c *expressionConverter) identifier(v *ast.IdentifierExpr) { + oldIdent := &expression.Ident{} + oldIdent.CIStr = v.Name + c.exprMap[v] = oldIdent +} + +func (c *expressionConverter) existsSubquery(v *ast.ExistsSubqueryExpr) { + subquery := c.exprMap[v.Sel].(expression.SubQuery) + c.exprMap[v] = &expression.ExistsSubQuery{Sel: subquery} +} + +func (c *expressionConverter) patternIn(v *ast.PatternInExpr) { + oldPatternIn := &expression.PatternIn{Not: v.Not} + if v.Sel != nil { + oldPatternIn.Sel = c.exprMap[v.Sel].(expression.SubQuery) + } + oldPatternIn.Expr = c.exprMap[v.Expr] + if v.List != nil { + oldPatternIn.List = make([]expression.Expression, len(v.List)) + for i, v := range v.List { + oldPatternIn.List[i] = c.exprMap[v] + } + } + c.exprMap[v] = oldPatternIn +} + +func (c *expressionConverter) isNull(v *ast.IsNullExpr) { + oldIsNull := &expression.IsNull{Not: v.Not} + oldIsNull.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsNull +} + +func (c *expressionConverter) isTruth(v *ast.IsTruthExpr) { + oldIsTruth := &expression.IsTruth{Not: v.Not, True: v.True} + oldIsTruth.Expr = c.exprMap[v.Expr] + c.exprMap[v] = oldIsTruth +} + +func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) { + oldPatternLike := &expression.PatternLike{ + Not: v.Not, + Escape: v.Escape, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternLike +} + +func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) { + if c.exprMap[v] == nil { + c.exprMap[v] = &expression.ParamMarker{} + c.paramMarkers = append(c.paramMarkers, v) + } +} + +func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) { + oldExpr := c.exprMap[v.Expr] + c.exprMap[v] = &expression.PExpr{Expr: oldExpr} +} + +func (c *expressionConverter) position(v *ast.PositionExpr) { + c.exprMap[v] = &expression.Position{N: v.N, Name: v.Name} +} + +func (c *expressionConverter) patternRegexp(v *ast.PatternRegexpExpr) { + oldPatternRegexp := &expression.PatternRegexp{ + Not: v.Not, + Expr: c.exprMap[v.Expr], + Pattern: c.exprMap[v.Pattern], + } + c.exprMap[v] = oldPatternRegexp +} + +func (c *expressionConverter) row(v *ast.RowExpr) { + oldRow := &expression.Row{} + oldRow.Values = make([]expression.Expression, len(v.Values)) + for i, val := range v.Values { + oldRow.Values[i] = c.exprMap[val] + } + c.exprMap[v] = oldRow +} + +func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) { + oldUnary := &expression.UnaryOperation{ + Op: v.Op, + V: c.exprMap[v.V], + } + c.exprMap[v] = oldUnary +} + +func (c *expressionConverter) values(v *ast.ValuesExpr) { + nameStr := joinColumnName(v.Column) + c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)} +} + +func (c *expressionConverter) variable(v *ast.VariableExpr) { + c.exprMap[v] = &expression.Variable{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + } +} + +func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) { + oldCall := &expression.Call{ + F: v.FnName, + } + oldCall.Args = make([]expression.Expression, len(v.Args)) + for i, val := range v.Args { + oldCall.Args[i] = c.exprMap[val] + } + c.exprMap[v] = oldCall +} + +func (c *expressionConverter) funcExtract(v *ast.FuncExtractExpr) { + oldExtract := &expression.Extract{Unit: v.Unit} + oldExtract.Date = c.exprMap[v.Date] + c.exprMap[v] = oldExtract +} + +func (c *expressionConverter) funcConvert(v *ast.FuncConvertExpr) { + c.exprMap[v] = &expression.FunctionConvert{ + Expr: c.exprMap[v.Expr], + Charset: v.Charset, + } +} + +func (c *expressionConverter) funcCast(v *ast.FuncCastExpr) { + oldCast := &expression.FunctionCast{ + Expr: c.exprMap[v.Expr], + Tp: v.Tp, + } + switch v.FunctionType { + case ast.CastBinaryOperator: + oldCast.FunctionType = expression.BinaryOperator + case ast.CastConvertFunction: + oldCast.FunctionType = expression.ConvertFunction + case ast.CastFunction: + oldCast.FunctionType = expression.CastFunction + } + c.exprMap[v] = oldCast +} + +func (c *expressionConverter) funcSubstring(v *ast.FuncSubstringExpr) { + oldSubstring := &expression.FunctionSubstring{ + Len: c.exprMap[v.Len], + Pos: c.exprMap[v.Pos], + StrExpr: c.exprMap[v.StrExpr], + } + c.exprMap[v] = oldSubstring +} + +func (c *expressionConverter) funcLocate(v *ast.FuncLocateExpr) { + oldLocate := &expression.FunctionLocate{ + Pos: c.exprMap[v.Pos], + Str: c.exprMap[v.Str], + SubStr: c.exprMap[v.SubStr], + } + c.exprMap[v] = oldLocate +} + +func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) { + oldTrim := &expression.FunctionTrim{ + Str: c.exprMap[v.Str], + RemStr: c.exprMap[v.RemStr], + } + switch v.Direction { + case ast.TrimBoth: + oldTrim.Direction = expression.TrimBoth + case ast.TrimBothDefault: + oldTrim.Direction = expression.TrimBothDefault + case ast.TrimLeading: + oldTrim.Direction = expression.TrimLeading + case ast.TrimTrailing: + oldTrim.Direction = expression.TrimTrailing + } + c.exprMap[v] = oldTrim +} + +func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) { + oldDateArith := &expression.DateArith{ + Unit: v.Unit, + Date: c.exprMap[v.Date], + Interval: c.exprMap[v.Interval], + } + switch v.Op { + case ast.DateAdd: + oldDateArith.Op = expression.DateAdd + case ast.DateSub: + oldDateArith.Op = expression.DateSub + } + c.exprMap[v] = oldDateArith +} + +func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) { + oldAggregate := &expression.Call{ + F: v.F, + Distinct: v.Distinct, + } + for _, val := range v.Args { + oldAggregate.Args = append(oldAggregate.Args, c.exprMap[val]) + } + c.exprMap[v] = oldAggregate +} diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go new file mode 100644 index 0000000000..df769237f1 --- /dev/null +++ b/optimizer/convert_stmt.go @@ -0,0 +1,1024 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/field" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/parser/coldef" + "github.com/pingcap/tidb/rset/rsets" + "github.com/pingcap/tidb/stmt" + "github.com/pingcap/tidb/stmt/stmts" + "github.com/pingcap/tidb/table" +) + +func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expression.Assignment, error) { + oldAssign := &expression.Assignment{ + ColName: joinColumnName(v.Column), + } + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldAssign.Expr = oldExpr + return oldAssign, nil +} + +func convertInsert(converter *expressionConverter, v *ast.InsertStmt) (stmt.Statement, error) { + insertValues := stmts.InsertValues{} + insertValues.Priority = v.Priority + tableName := v.Table.TableRefs.Left.(*ast.TableSource).Source.(*ast.TableName) + insertValues.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} + for _, val := range v.Columns { + insertValues.ColNames = append(insertValues.ColNames, joinColumnName(val)) + } + for _, row := range v.Lists { + var oldRow []expression.Expression + for _, val := range row { + oldExpr, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldRow = append(oldRow, oldExpr) + } + insertValues.Lists = append(insertValues.Lists, oldRow) + } + for _, assign := range v.Setlist { + oldAssign, err := convertAssignment(converter, assign) + if err != nil { + return nil, errors.Trace(err) + } + insertValues.Setlist = append(insertValues.Setlist, oldAssign) + } + + if v.Select != nil { + var err error + switch x := v.Select.(type) { + case *ast.SelectStmt: + insertValues.Sel, err = convertSelect(converter, x) + case *ast.UnionStmt: + insertValues.Sel, err = convertUnion(converter, x) + } + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Replace { + return &stmts.ReplaceIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), + }, nil + } + oldInsert := &stmts.InsertIntoStmt{ + InsertValues: insertValues, + Text: v.Text(), + } + for _, onDup := range v.OnDuplicate { + oldOnDup, err := convertAssignment(converter, onDup) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, oldOnDup) + } + return oldInsert, nil +} + +func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { + oldDelete := &stmts.DeleteStmt{ + BeforeFrom: v.BeforeFrom, + Ignore: v.Ignore, + LowPriority: v.LowPriority, + MultiTable: v.MultiTable, + Quick: v.Quick, + Text: v.Text(), + } + oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + oldDelete.Refs = oldRefs + for _, val := range v.Tables { + tableIdent := table.Ident{Schema: val.Schema, Name: val.Name} + oldDelete.TableIdents = append(oldDelete.TableIdents, tableIdent) + } + if v.Where != nil { + oldDelete.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Order != nil { + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) + } + oldDelete.Order = oldOrderBy + } + if v.Limit != nil { + oldDelete.Limit = &rsets.LimitRset{Count: v.Limit.Count} + } + return oldDelete, nil +} + +func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { + oldUpdate := &stmts.UpdateStmt{ + Ignore: v.Ignore, + MultipleTable: v.MultipleTable, + LowPriority: v.LowPriority, + Text: v.Text(), + } + var err error + oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + if v.Where != nil { + oldUpdate.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + for _, val := range v.List { + var oldAssign *expression.Assignment + oldAssign, err = convertAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldUpdate.List = append(oldUpdate.List, oldAssign) + } + if v.Order != nil { + oldOrderBy, err := convertOrderBy(converter, v.Order) + if err != nil { + return nil, errors.Trace(err) + } + oldUpdate.Order = oldOrderBy + } + if v.Limit != nil { + oldUpdate.Limit = &rsets.LimitRset{Count: v.Limit.Count} + } + return oldUpdate, nil +} + +func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { + oldSelect := &stmts.SelectStmt{ + Distinct: s.Distinct, + Text: s.Text(), + } + oldSelect.Fields = make([]*field.Field, len(s.Fields.Fields)) + for i, val := range s.Fields.Fields { + oldField := &field.Field{} + oldField.AsName = val.AsName.O + var err error + if val.Expr != nil { + oldField.Expr, err = convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + // TODO: handle parenthesesed column name expression, which should not set AsName. + if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" { + oldField.AsName = val.Text() + } + } else if val.WildCard != nil { + str := "*" + if val.WildCard.Table.O != "" { + str = val.WildCard.Table.O + ".*" + if val.WildCard.Schema.O != "" { + str = val.WildCard.Schema.O + "." + str + } + } + oldField.Expr = &expression.Ident{CIStr: model.NewCIStr(str)} + } + oldSelect.Fields[i] = oldField + } + var err error + if s.From != nil { + oldSelect.From, err = convertJoin(converter, s.From.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Where != nil { + oldSelect.Where = &rsets.WhereRset{} + oldSelect.Where.Expr, err = convertExpr(converter, s.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + + if s.GroupBy != nil { + oldSelect.GroupBy, err = convertGroupBy(converter, s.GroupBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Having != nil { + oldSelect.Having, err = convertHaving(converter, s.Having) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.OrderBy != nil { + oldSelect.OrderBy, err = convertOrderBy(converter, s.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + } + if s.Limit != nil { + if s.Limit.Offset > 0 { + oldSelect.Offset = &rsets.OffsetRset{Count: s.Limit.Offset} + } + if s.Limit.Count > 0 { + oldSelect.Limit = &rsets.LimitRset{Count: s.Limit.Count} + } + } + switch s.LockTp { + case ast.SelectLockForUpdate: + oldSelect.Lock = coldef.SelectLockForUpdate + case ast.SelectLockInShareMode: + oldSelect.Lock = coldef.SelectLockInShareMode + case ast.SelectLockNone: + oldSelect.Lock = coldef.SelectLockNone + } + return oldSelect, nil +} + +func convertUnion(converter *expressionConverter, u *ast.UnionStmt) (*stmts.UnionStmt, error) { + oldUnion := &stmts.UnionStmt{ + Text: u.Text(), + } + oldUnion.Selects = make([]*stmts.SelectStmt, len(u.Selects)) + oldUnion.Distincts = make([]bool, len(u.Selects)-1) + if u.Distinct { + for i := range oldUnion.Distincts { + oldUnion.Distincts[i] = true + } + } + for i, val := range u.Selects { + oldSelect, err := convertSelect(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.Selects[i] = oldSelect + } + if u.OrderBy != nil { + oldOrderBy, err := convertOrderBy(converter, u.OrderBy) + if err != nil { + return nil, errors.Trace(err) + } + oldUnion.OrderBy = oldOrderBy + } + if u.Limit != nil { + if u.Limit.Offset > 0 { + oldUnion.Offset = &rsets.OffsetRset{Count: u.Limit.Offset} + } + if u.Limit.Count > 0 { + oldUnion.Limit = &rsets.LimitRset{Count: u.Limit.Count} + } + } + // Union order by can not converted to old because it is pushed to select statements. + return oldUnion, nil +} + +func convertJoin(converter *expressionConverter, join *ast.Join) (*rsets.JoinRset, error) { + oldJoin := &rsets.JoinRset{} + switch join.Tp { + case ast.CrossJoin: + oldJoin.Type = rsets.CrossJoin + case ast.LeftJoin: + oldJoin.Type = rsets.LeftJoin + case ast.RightJoin: + oldJoin.Type = rsets.RightJoin + } + switch l := join.Left.(type) { + case *ast.Join: + oldLeft, err := convertJoin(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + case *ast.TableSource: + oldLeft, err := convertTableSource(converter, l) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Left = oldLeft + } + + switch r := join.Right.(type) { + case *ast.Join: + oldRight, err := convertJoin(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case *ast.TableSource: + oldRight, err := convertTableSource(converter, r) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.Right = oldRight + case nil: + } + + if join.On != nil { + oldOn, err := convertExpr(converter, join.On.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldJoin.On = oldOn + } + return oldJoin, nil +} + +func convertTableSource(converter *expressionConverter, ts *ast.TableSource) (*rsets.TableSource, error) { + oldTs := &rsets.TableSource{} + oldTs.Name = ts.AsName.O + switch src := ts.Source.(type) { + case *ast.TableName: + oldTs.Source = table.Ident{Schema: src.Schema, Name: src.Name} + case *ast.SelectStmt: + oldSelect, err := convertSelect(converter, src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldSelect + case *ast.UnionStmt: + oldUnion, err := convertUnion(converter, src) + if err != nil { + return nil, errors.Trace(err) + } + oldTs.Source = oldUnion + } + return oldTs, nil +} + +func convertGroupBy(converter *expressionConverter, gb *ast.GroupByClause) (*rsets.GroupByRset, error) { + oldGroupBy := &rsets.GroupByRset{ + By: make([]expression.Expression, len(gb.Items)), + } + for i, val := range gb.Items { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldGroupBy.By[i] = oldExpr + } + return oldGroupBy, nil +} + +func convertHaving(converter *expressionConverter, having *ast.HavingClause) (*rsets.HavingRset, error) { + oldHaving := &rsets.HavingRset{} + oldExpr, err := convertExpr(converter, having.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldHaving.Expr = oldExpr + return oldHaving, nil +} + +func convertOrderBy(converter *expressionConverter, orderBy *ast.OrderByClause) (*rsets.OrderByRset, error) { + oldOrderBy := &rsets.OrderByRset{} + oldOrderBy.By = make([]rsets.OrderByItem, len(orderBy.Items)) + for i, val := range orderBy.Items { + oldByItemExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldOrderBy.By[i].Expr = oldByItemExpr + oldOrderBy.By[i].Asc = !val.Desc + } + return oldOrderBy, nil +} + +func convertCreateDatabase(converter *expressionConverter, v *ast.CreateDatabaseStmt) (*stmts.CreateDatabaseStmt, error) { + oldCreateDatabase := &stmts.CreateDatabaseStmt{ + IfNotExists: v.IfNotExists, + Name: v.Name, + Text: v.Text(), + } + if len(v.Options) != 0 { + oldCreateDatabase.Opt = &coldef.CharsetOpt{} + for _, val := range v.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + oldCreateDatabase.Opt.Chs = val.Value + case ast.DatabaseOptionCollate: + oldCreateDatabase.Opt.Col = val.Value + } + } + } + return oldCreateDatabase, nil +} + +func convertDropDatabase(converter *expressionConverter, v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, error) { + return &stmts.DropDatabaseStmt{ + IfExists: v.IfExists, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertColumnOption(converter *expressionConverter, v *ast.ColumnOption) (*coldef.ConstraintOpt, error) { + oldColumnOpt := &coldef.ConstraintOpt{} + switch v.Tp { + case ast.ColumnOptionAutoIncrement: + oldColumnOpt.Tp = coldef.ConstrAutoIncrement + case ast.ColumnOptionComment: + oldColumnOpt.Tp = coldef.ConstrComment + case ast.ColumnOptionDefaultValue: + oldColumnOpt.Tp = coldef.ConstrDefaultValue + case ast.ColumnOptionIndex: + oldColumnOpt.Tp = coldef.ConstrIndex + case ast.ColumnOptionKey: + oldColumnOpt.Tp = coldef.ConstrKey + case ast.ColumnOptionFulltext: + oldColumnOpt.Tp = coldef.ConstrFulltext + case ast.ColumnOptionNotNull: + oldColumnOpt.Tp = coldef.ConstrNotNull + case ast.ColumnOptionNoOption: + oldColumnOpt.Tp = coldef.ConstrNoConstr + case ast.ColumnOptionOnUpdate: + oldColumnOpt.Tp = coldef.ConstrOnUpdate + case ast.ColumnOptionPrimaryKey: + oldColumnOpt.Tp = coldef.ConstrPrimaryKey + case ast.ColumnOptionNull: + oldColumnOpt.Tp = coldef.ConstrNull + case ast.ColumnOptionUniq: + oldColumnOpt.Tp = coldef.ConstrUniq + case ast.ColumnOptionUniqIndex: + oldColumnOpt.Tp = coldef.ConstrUniqIndex + case ast.ColumnOptionUniqKey: + oldColumnOpt.Tp = coldef.ConstrUniqKey + } + if v.Expr != nil { + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldColumnOpt.Evalue = oldExpr + } + return oldColumnOpt, nil +} + +func convertColumnDef(converter *expressionConverter, v *ast.ColumnDef) (*coldef.ColumnDef, error) { + oldColDef := &coldef.ColumnDef{ + Name: v.Name.Name.O, + Tp: v.Tp, + } + for _, val := range v.Options { + oldOpt, err := convertColumnOption(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldColDef.Constraints = append(oldColDef.Constraints, oldOpt) + } + return oldColDef, nil +} + +func convertIndexColNames(v []*ast.IndexColName) (out []*coldef.IndexColName) { + for _, val := range v { + oldIndexColKey := &coldef.IndexColName{ + ColumnName: val.Column.Name.O, + Length: val.Length, + } + out = append(out, oldIndexColKey) + } + return +} + +func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*coldef.TableConstraint, error) { + oldConstraint := &coldef.TableConstraint{ConstrName: v.Name} + switch v.Tp { + case ast.ConstraintNoConstraint: + oldConstraint.Tp = coldef.ConstrNoConstr + case ast.ConstraintPrimaryKey: + oldConstraint.Tp = coldef.ConstrPrimaryKey + case ast.ConstraintKey: + oldConstraint.Tp = coldef.ConstrKey + case ast.ConstraintIndex: + oldConstraint.Tp = coldef.ConstrIndex + case ast.ConstraintUniq: + oldConstraint.Tp = coldef.ConstrUniq + case ast.ConstraintUniqKey: + oldConstraint.Tp = coldef.ConstrUniqKey + case ast.ConstraintUniqIndex: + oldConstraint.Tp = coldef.ConstrUniqIndex + case ast.ConstraintForeignKey: + oldConstraint.Tp = coldef.ConstrForeignKey + case ast.ConstraintFulltext: + oldConstraint.Tp = coldef.ConstrFulltext + } + oldConstraint.Keys = convertIndexColNames(v.Keys) + if v.Refer != nil { + oldConstraint.Refer = &coldef.ReferenceDef{ + TableIdent: table.Ident{Schema: v.Refer.Table.Schema, Name: v.Refer.Table.Name}, + IndexColNames: convertIndexColNames(v.Refer.IndexColNames), + } + } + return oldConstraint, nil +} + +func convertCreateTable(converter *expressionConverter, v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { + oldCreateTable := &stmts.CreateTableStmt{ + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + IfNotExists: v.IfNotExists, + Text: v.Text(), + } + for _, val := range v.Cols { + oldColDef, err := convertColumnDef(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Cols = append(oldCreateTable.Cols, oldColDef) + } + for _, val := range v.Constraints { + oldConstr, err := convertConstraint(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Constraints = append(oldCreateTable.Constraints, oldConstr) + } + if len(v.Options) != 0 { + oldTableOpt := &coldef.TableOption{} + for _, val := range v.Options { + switch val.Tp { + case ast.TableOptionEngine: + oldTableOpt.Engine = val.StrValue + case ast.TableOptionCharset: + oldTableOpt.Charset = val.StrValue + case ast.TableOptionCollate: + oldTableOpt.Collate = val.StrValue + case ast.TableOptionAutoIncrement: + oldTableOpt.AutoIncrement = val.UintValue + } + } + oldCreateTable.Opt = oldTableOpt + } + return oldCreateTable, nil +} + +func convertDropTable(converter *expressionConverter, v *ast.DropTableStmt) (*stmts.DropTableStmt, error) { + oldDropTable := &stmts.DropTableStmt{ + IfExists: v.IfExists, + Text: v.Text(), + } + oldDropTable.TableIdents = make([]table.Ident, len(v.Tables)) + for i, val := range v.Tables { + oldDropTable.TableIdents[i] = table.Ident{ + Schema: val.Schema, + Name: val.Name, + } + } + return oldDropTable, nil +} + +func convertCreateIndex(converter *expressionConverter, v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) { + oldCreateIndex := &stmts.CreateIndexStmt{ + IndexName: v.IndexName, + Unique: v.Unique, + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + } + oldCreateIndex.IndexColNames = make([]*coldef.IndexColName, len(v.IndexColNames)) + for i, val := range v.IndexColNames { + oldIndexColName := &coldef.IndexColName{ + ColumnName: joinColumnName(val.Column), + Length: val.Length, + } + oldCreateIndex.IndexColNames[i] = oldIndexColName + } + return oldCreateIndex, nil +} + +func convertDropIndex(converter *expressionConverter, v *ast.DropIndexStmt) (*stmts.DropIndexStmt, error) { + return &stmts.DropIndexStmt{ + IfExists: v.IfExists, + IndexName: v.IndexName, + Text: v.Text(), + }, nil +} + +func convertAlterTableSpec(converter *expressionConverter, v *ast.AlterTableSpec) (*ddl.AlterSpecification, error) { + oldAlterSpec := &ddl.AlterSpecification{ + Name: v.Name, + } + switch v.Tp { + case ast.AlterTableAddConstraint: + oldAlterSpec.Action = ddl.AlterAddConstr + case ast.AlterTableAddColumn: + oldAlterSpec.Action = ddl.AlterAddColumn + case ast.AlterTableDropColumn: + oldAlterSpec.Action = ddl.AlterDropColumn + case ast.AlterTableDropForeignKey: + oldAlterSpec.Action = ddl.AlterDropForeignKey + case ast.AlterTableDropIndex: + oldAlterSpec.Action = ddl.AlterDropIndex + case ast.AlterTableDropPrimaryKey: + oldAlterSpec.Action = ddl.AlterDropPrimaryKey + case ast.AlterTableOption: + oldAlterSpec.Action = ddl.AlterTableOpt + } + if v.Column != nil { + oldColDef, err := convertColumnDef(converter, v.Column) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Column = oldColDef + } + if v.Position != nil { + oldAlterSpec.Position = &ddl.ColumnPosition{} + switch v.Position.Tp { + case ast.ColumnPositionNone: + oldAlterSpec.Position.Type = ddl.ColumnPositionNone + case ast.ColumnPositionFirst: + oldAlterSpec.Position.Type = ddl.ColumnPositionFirst + case ast.ColumnPositionAfter: + oldAlterSpec.Position.Type = ddl.ColumnPositionAfter + } + if v.Position.RelativeColumn != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.Position.RelativeColumn) + } + } + if v.DropColumn != nil { + oldAlterSpec.Name = joinColumnName(v.DropColumn) + } + if v.Constraint != nil { + oldConstraint, err := convertConstraint(converter, v.Constraint) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Constraint = oldConstraint + } + for _, val := range v.Options { + oldOpt := &coldef.TableOpt{ + StrValue: val.StrValue, + UintValue: val.UintValue, + } + switch val.Tp { + case ast.TableOptionNone: + oldOpt.Tp = coldef.TblOptNone + case ast.TableOptionEngine: + oldOpt.Tp = coldef.TblOptEngine + case ast.TableOptionCharset: + oldOpt.Tp = coldef.TblOptCharset + case ast.TableOptionCollate: + oldOpt.Tp = coldef.TblOptCollate + case ast.TableOptionAutoIncrement: + oldOpt.Tp = coldef.TblOptAutoIncrement + case ast.TableOptionComment: + oldOpt.Tp = coldef.TblOptComment + case ast.TableOptionAvgRowLength: + oldOpt.Tp = coldef.TblOptAvgRowLength + case ast.TableOptionCheckSum: + oldOpt.Tp = coldef.TblOptCheckSum + case ast.TableOptionCompression: + oldOpt.Tp = coldef.TblOptCompression + case ast.TableOptionConnection: + oldOpt.Tp = coldef.TblOptConnection + case ast.TableOptionPassword: + oldOpt.Tp = coldef.TblOptPassword + case ast.TableOptionKeyBlockSize: + oldOpt.Tp = coldef.TblOptKeyBlockSize + case ast.TableOptionMaxRows: + oldOpt.Tp = coldef.TblOptMaxRows + case ast.TableOptionMinRows: + oldOpt.Tp = coldef.TblOptMinRows + } + oldAlterSpec.TableOpts = append(oldAlterSpec.TableOpts, oldOpt) + } + return oldAlterSpec, nil +} + +func convertAlterTable(converter *expressionConverter, v *ast.AlterTableStmt) (*stmts.AlterTableStmt, error) { + oldAlterTable := &stmts.AlterTableStmt{ + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertAlterTableSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterTable.Specs = append(oldAlterTable.Specs, oldSpec) + } + return oldAlterTable, nil +} + +func convertTruncateTable(converter *expressionConverter, v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { + return &stmts.TruncateTableStmt{ + TableIdent: table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + }, + Text: v.Text(), + }, nil +} + +func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { + oldExplain := &stmts.ExplainStmt{ + Text: v.Text(), + } + var err error + switch x := v.Stmt.(type) { + case *ast.SelectStmt: + oldExplain.S, err = convertSelect(converter, x) + case *ast.UpdateStmt: + oldExplain.S, err = convertUpdate(converter, x) + case *ast.DeleteStmt: + oldExplain.S, err = convertDelete(converter, x) + case *ast.InsertStmt: + oldExplain.S, err = convertInsert(converter, x) + case *ast.ShowStmt: + oldExplain.S, err = convertShow(converter, x) + } + if err != nil { + return nil, errors.Trace(err) + } + return oldExplain, nil +} + +func convertPrepare(converter *expressionConverter, v *ast.PrepareStmt) (*stmts.PreparedStmt, error) { + oldPrepare := &stmts.PreparedStmt{ + InPrepare: true, + Name: v.Name, + SQLText: v.SQLText, + Text: v.Text(), + } + if v.SQLVar != nil { + oldSQLVar, err := convertExpr(converter, v.SQLVar) + if err != nil { + return nil, errors.Trace(err) + } + oldPrepare.SQLVar = oldSQLVar.(*expression.Variable) + } + return oldPrepare, nil +} + +func convertDeallocate(converter *expressionConverter, v *ast.DeallocateStmt) (*stmts.DeallocateStmt, error) { + return &stmts.DeallocateStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + }, nil +} + +func convertExecute(converter *expressionConverter, v *ast.ExecuteStmt) (*stmts.ExecuteStmt, error) { + oldExec := &stmts.ExecuteStmt{ + ID: v.ID, + Name: v.Name, + Text: v.Text(), + } + oldExec.UsingVars = make([]expression.Expression, len(v.UsingVars)) + for i, val := range v.UsingVars { + oldVar, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldExec.UsingVars[i] = oldVar + } + return oldExec, nil +} + +func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowStmt, error) { + oldShow := &stmts.ShowStmt{ + DBName: v.DBName, + Flag: v.Flag, + Full: v.Full, + GlobalScope: v.GlobalScope, + Text: v.Text(), + } + if v.Table != nil { + oldShow.TableIdent = table.Ident{ + Schema: v.Table.Schema, + Name: v.Table.Name, + } + } + if v.Column != nil { + oldShow.ColumnName = joinColumnName(v.Column) + } + if v.Where != nil { + oldWhere, err := convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Where = oldWhere + } + if v.Pattern != nil { + oldPattern, err := convertExpr(converter, v.Pattern) + if err != nil { + return nil, errors.Trace(err) + } + oldShow.Pattern = oldPattern.(*expression.PatternLike) + } + switch v.Tp { + case ast.ShowCharset: + oldShow.Target = stmt.ShowCharset + case ast.ShowCollation: + oldShow.Target = stmt.ShowCollation + case ast.ShowColumns: + oldShow.Target = stmt.ShowColumns + case ast.ShowCreateTable: + oldShow.Target = stmt.ShowCreateTable + case ast.ShowDatabases: + oldShow.Target = stmt.ShowDatabases + case ast.ShowTables: + oldShow.Target = stmt.ShowTables + case ast.ShowEngines: + oldShow.Target = stmt.ShowEngines + case ast.ShowVariables: + oldShow.Target = stmt.ShowVariables + case ast.ShowWarnings: + oldShow.Target = stmt.ShowWarnings + case ast.ShowNone: + oldShow.Target = stmt.ShowNone + } + return oldShow, nil +} + +func convertBegin(converter *expressionConverter, v *ast.BeginStmt) (*stmts.BeginStmt, error) { + return &stmts.BeginStmt{ + Text: v.Text(), + }, nil +} + +func convertCommit(converter *expressionConverter, v *ast.CommitStmt) (*stmts.CommitStmt, error) { + return &stmts.CommitStmt{ + Text: v.Text(), + }, nil +} + +func convertRollback(converter *expressionConverter, v *ast.RollbackStmt) (*stmts.RollbackStmt, error) { + return &stmts.RollbackStmt{ + Text: v.Text(), + }, nil +} + +func convertUse(converter *expressionConverter, v *ast.UseStmt) (*stmts.UseStmt, error) { + return &stmts.UseStmt{ + DBName: v.DBName, + Text: v.Text(), + }, nil +} + +func convertVariableAssignment(converter *expressionConverter, v *ast.VariableAssignment) (*stmts.VariableAssignment, error) { + oldValue, err := convertExpr(converter, v.Value) + if err != nil { + return nil, errors.Trace(err) + } + + return &stmts.VariableAssignment{ + IsGlobal: v.IsGlobal, + IsSystem: v.IsSystem, + Name: v.Name, + Value: oldValue, + Text: v.Text(), + }, nil +} + +func convertSet(converter *expressionConverter, v *ast.SetStmt) (*stmts.SetStmt, error) { + oldSet := &stmts.SetStmt{ + Text: v.Text(), + Variables: make([]*stmts.VariableAssignment, len(v.Variables)), + } + for i, val := range v.Variables { + oldAssign, err := convertVariableAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldSet.Variables[i] = oldAssign + } + return oldSet, nil +} + +func convertSetCharset(converter *expressionConverter, v *ast.SetCharsetStmt) (*stmts.SetCharsetStmt, error) { + return &stmts.SetCharsetStmt{ + Charset: v.Charset, + Collate: v.Collate, + Text: v.Text(), + }, nil +} + +func convertSetPwd(converter *expressionConverter, v *ast.SetPwdStmt) (*stmts.SetPwdStmt, error) { + return &stmts.SetPwdStmt{ + User: v.User, + Password: v.Password, + Text: v.Text(), + }, nil +} + +func convertUserSpec(converter *expressionConverter, v *ast.UserSpec) (*coldef.UserSpecification, error) { + oldSpec := &coldef.UserSpecification{ + User: v.User, + } + if v.AuthOpt != nil { + oldAuthOpt := &coldef.AuthOption{ + AuthString: v.AuthOpt.AuthString, + ByAuthString: v.AuthOpt.ByAuthString, + HashString: v.AuthOpt.HashString, + } + oldSpec.AuthOpt = oldAuthOpt + } + return oldSpec, nil +} + +func convertCreateUser(converter *expressionConverter, v *ast.CreateUserStmt) (*stmts.CreateUserStmt, error) { + oldCreateUser := &stmts.CreateUserStmt{ + IfNotExists: v.IfNotExists, + Text: v.Text(), + } + for _, val := range v.Specs { + oldSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateUser.Specs = append(oldCreateUser.Specs, oldSpec) + } + return oldCreateUser, nil +} + +func convertDo(converter *expressionConverter, v *ast.DoStmt) (*stmts.DoStmt, error) { + oldDo := &stmts.DoStmt{ + Text: v.Text(), + Exprs: make([]expression.Expression, len(v.Exprs)), + } + for i, val := range v.Exprs { + oldExpr, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldDo.Exprs[i] = oldExpr + } + return oldDo, nil +} + +func convertPrivElem(converter *expressionConverter, v *ast.PrivElem) (*coldef.PrivElem, error) { + oldPrim := &coldef.PrivElem{ + Priv: v.Priv, + } + for _, val := range v.Cols { + oldPrim.Cols = append(oldPrim.Cols, joinColumnName(val)) + } + return oldPrim, nil +} + +func convertGrant(converter *expressionConverter, v *ast.GrantStmt) (*stmts.GrantStmt, error) { + oldGrant := &stmts.GrantStmt{ + Text: v.Text(), + } + for _, val := range v.Privs { + oldPrim, err := convertPrivElem(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Privs = append(oldGrant.Privs, oldPrim) + } + switch v.ObjectType { + case ast.ObjectTypeNone: + oldGrant.ObjectType = coldef.ObjectTypeNone + case ast.ObjectTypeTable: + oldGrant.ObjectType = coldef.ObjectTypeTable + } + if v.Level != nil { + oldGrantLevel := &coldef.GrantLevel{ + DBName: v.Level.DBName, + TableName: v.Level.TableName, + } + switch v.Level.Level { + case ast.GrantLevelDB: + oldGrantLevel.Level = coldef.GrantLevelDB + case ast.GrantLevelGlobal: + oldGrantLevel.Level = coldef.GrantLevelGlobal + case ast.GrantLevelNone: + oldGrantLevel.Level = coldef.GrantLevelNone + case ast.GrantLevelTable: + oldGrantLevel.Level = coldef.GrantLevelTable + } + oldGrant.Level = oldGrantLevel + } + for _, val := range v.Users { + oldUserSpec, err := convertUserSpec(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldGrant.Users = append(oldGrant.Users, oldUserSpec) + } + return oldGrant, nil +} diff --git a/optimizer/evaluator.go b/optimizer/evaluator.go new file mode 100644 index 0000000000..fffb5b53b3 --- /dev/null +++ b/optimizer/evaluator.go @@ -0,0 +1,317 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/plan" +) + +// Evaluator is a ast visitor that evaluates an expression. +type Evaluator struct { + // columnMap is the map from ColumnName to the position of the rowStack. + // It is used to find the value of the column. + columnMap map[*ast.ColumnName]position + // rowStack is the current row values while scanning. + // It should be updated after scaned a new row. + rowStack []*plan.Row + + // the map from AggregateFuncExpr to aggregator index. + aggregateMap map[*ast.AggregateFuncExpr]int + + // aggregators for the current row + // only outer query aggregate functions are handled. + aggregators []Aggregator + + // when aggregation phase is done, the input is + aggregateDone bool + + err error +} + +type position struct { + stackOffset int + fieldList bool + columnOffset int +} + +// Enter implements ast.Visitor interface. +func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return +} + +// Leave implements ast.Visitor interface. +func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ValueExpr: + ok = true + case *ast.BetweenExpr: + ok = e.between(v) + case *ast.BinaryOperationExpr: + ok = e.binaryOperation(v) + case *ast.WhenClause: + ok = e.whenClause(v) + case *ast.CaseExpr: + ok = e.caseExpr(v) + case *ast.SubqueryExpr: + ok = e.subquery(v) + case *ast.CompareSubqueryExpr: + ok = e.compareSubquery(v) + case *ast.ColumnNameExpr: + ok = e.columnName(v) + case *ast.DefaultExpr: + ok = e.defaultExpr(v) + case *ast.IdentifierExpr: + ok = e.identifier(v) + case *ast.ExistsSubqueryExpr: + ok = e.existsSubquery(v) + case *ast.PatternInExpr: + ok = e.patternIn(v) + case *ast.IsNullExpr: + ok = e.isNull(v) + case *ast.IsTruthExpr: + ok = e.isTruth(v) + case *ast.PatternLikeExpr: + ok = e.patternLike(v) + case *ast.ParamMarkerExpr: + ok = e.paramMarker(v) + case *ast.ParenthesesExpr: + ok = e.parentheses(v) + case *ast.PositionExpr: + ok = e.position(v) + case *ast.PatternRegexpExpr: + ok = e.patternRegexp(v) + case *ast.RowExpr: + ok = e.row(v) + case *ast.UnaryOperationExpr: + ok = e.unaryOperation(v) + case *ast.ValuesExpr: + ok = e.values(v) + case *ast.VariableExpr: + ok = e.variable(v) + case *ast.FuncCallExpr: + ok = e.funcCall(v) + case *ast.FuncExtractExpr: + ok = e.funcExtract(v) + case *ast.FuncConvertExpr: + ok = e.funcConvert(v) + case *ast.FuncCastExpr: + ok = e.funcCast(v) + case *ast.FuncSubstringExpr: + ok = e.funcSubstring(v) + case *ast.FuncLocateExpr: + ok = e.funcLocate(v) + case *ast.FuncTrimExpr: + ok = e.funcTrim(v) + case *ast.AggregateFuncExpr: + ok = e.aggregateFunc(v) + } + out = in + return +} + +func checkAllOneColumn(exprs ...ast.ExprNode) bool { + for _, expr := range exprs { + switch v := expr.(type) { + case *ast.RowExpr: + return false + case *ast.SubqueryExpr: + if len(v.Query.GetResultFields()) != 1 { + return false + } + } + } + return true +} + +func (e *Evaluator) between(v *ast.BetweenExpr) bool { + if !checkAllOneColumn(v.Expr, v.Left, v.Right) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + + var l, r ast.ExprNode + op := opcode.AndAnd + + if v.Not { + // v < lv || v > rv + op = opcode.OrOr + l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right} + } else { + // v >= lv && v <= rv + l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left} + r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right} + } + + ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r} + ret.Accept(e) + return e.err == nil +} + +func columnCount(e ast.ExprNode) (int, error) { + switch x := e.(type) { + case *ast.RowExpr: + n := len(x.Values) + if n <= 1 { + return 0, errors.Errorf("Operand should contain >= 2 columns for Row") + } + return n, nil + case *ast.SubqueryExpr: + return len(x.Query.GetResultFields()), nil + default: + return 1, nil + } +} + +func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error { + l, err := columnCount(e) + if err != nil { + return errors.Trace(err) + } + var n int + for _, arg := range args { + n, err = columnCount(arg) + if err != nil { + return errors.Trace(err) + } + + if n != l { + return errors.Errorf("Operand should contain %d column(s)", l) + } + } + return nil +} + +func (e *Evaluator) whenClause(v *ast.WhenClause) bool { + return true +} + +func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool { + return true +} + +func (e *Evaluator) subquery(v *ast.SubqueryExpr) bool { + return true +} + +func (e *Evaluator) compareSubquery(v *ast.CompareSubqueryExpr) bool { + return true +} + +func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool { + return true +} + +func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool { + return true +} + +func (e *Evaluator) identifier(v *ast.IdentifierExpr) bool { + return true +} + +func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool { + return true +} + +func (e *Evaluator) patternIn(v *ast.PatternInExpr) bool { + return true +} + +func (e *Evaluator) isNull(v *ast.IsNullExpr) bool { + return true +} + +func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool { + return true +} + +func (e *Evaluator) patternLike(v *ast.PatternLikeExpr) bool { + return true +} + +func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool { + return true +} + +func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool { + return true +} + +func (e *Evaluator) position(v *ast.PositionExpr) bool { + return true +} + +func (e *Evaluator) patternRegexp(v *ast.PatternRegexpExpr) bool { + return true +} + +func (e *Evaluator) row(v *ast.RowExpr) bool { + return true +} + +func (e *Evaluator) unaryOperation(v *ast.UnaryOperationExpr) bool { + return true +} + +func (e *Evaluator) values(v *ast.ValuesExpr) bool { + return true +} + +func (e *Evaluator) variable(v *ast.VariableExpr) bool { + return true +} + +func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool { + return true +} + +func (e *Evaluator) funcExtract(v *ast.FuncExtractExpr) bool { + return true +} + +func (e *Evaluator) funcConvert(v *ast.FuncConvertExpr) bool { + return true +} + +func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool { + return true +} + +func (e *Evaluator) funcSubstring(v *ast.FuncSubstringExpr) bool { + return true +} + +func (e *Evaluator) funcLocate(v *ast.FuncLocateExpr) bool { + return true +} + +func (e *Evaluator) funcTrim(v *ast.FuncTrimExpr) bool { + return true +} + +func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { + idx := e.aggregateMap[v] + aggr := e.aggregators[idx] + if e.aggregateDone { + v.SetValue(aggr.Output()) + return true + } + // TODO: currently only single argument aggregate functions are supported. + e.err = aggr.Input(v.Args[0].GetValue()) + return e.err == nil +} diff --git a/optimizer/evaluator_binop.go b/optimizer/evaluator_binop.go new file mode 100644 index 0000000000..0cda326144 --- /dev/null +++ b/optimizer/evaluator_binop.go @@ -0,0 +1,527 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "math" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/types" +) + +const ( + zeroI64 int64 = 0 + oneI64 int64 = 1 +) + +func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool { + // all operands must have same column. + if e.err = hasSameColumnCount(o.L, o.R); e.err != nil { + return false + } + + // row constructor only supports comparison operation. + switch o.Op { + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + default: + if !checkAllOneColumn(o.L) { + e.err = errors.Errorf("Operand should contain 1 column(s)") + return false + } + } + + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + + switch o.Op { + case opcode.AndAnd, opcode.OrOr, opcode.LogicXor: + return e.handleLogicOperation(o) + case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ: + return e.handleComparisonOp(o) + case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor: + // TODO: MySQL doesn't support and not, we should remove it later. + return e.handleBitOp(o) + case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv: + return e.handleArithmeticOp(o) + default: + panic("should never happen") + } +} + +func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool { + leftVal, err := types.Convert(o.L.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + rightVal, err := types.Convert(o.R.GetValue(), o.GetType()) + if err != nil { + e.err = err + return false + } + if leftVal == nil || rightVal == nil { + o.SetValue(nil) + return true + } + var boolVal bool + switch o.Op { + case opcode.AndAnd: + boolVal = leftVal != zeroI64 && rightVal != zeroI64 + case opcode.OrOr: + boolVal = leftVal != zeroI64 || rightVal != zeroI64 + case opcode.LogicXor: + boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64) + default: + panic("should never happen") + } + if boolVal { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + if types.IsNil(a) || types.IsNil(b) { + // for <=>, if a and b are both nil, return true. + // if a or b is nil, return false. + if o.Op == opcode.NullEQ { + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + } else { + o.SetValue(nil) + } + return true + } + + n, err := types.Compare(a, b) + if err != nil { + e.err = err + return false + } + + r, err := getCompResult(o.Op, n) + if err != nil { + e.err = err + return false + } + if r { + o.SetValue(oneI64) + } else { + o.SetValue(zeroI64) + } + return true +} + +func getCompResult(op opcode.Op, value int) (bool, error) { + switch op { + case opcode.LT: + return value < 0, nil + case opcode.LE: + return value <= 0, nil + case opcode.GE: + return value >= 0, nil + case opcode.GT: + return value > 0, nil + case opcode.EQ: + return value == 0, nil + case opcode.NE: + return value != 0, nil + case opcode.NullEQ: + return value == 0, nil + default: + return false, errors.Errorf("invalid op %v in comparision operation", op) + } +} + +func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool { + a, b := types.Coerce(o.L.GetValue(), o.R.GetValue()) + + if types.IsNil(a) || types.IsNil(b) { + o.SetValue(nil) + return true + } + + x, err := types.ToInt64(a) + if err != nil { + e.err = err + return false + } + + y, err := types.ToInt64(b) + if err != nil { + e.err = err + return false + } + + // use a int64 for bit operator, return uint64 + switch o.Op { + case opcode.And: + o.SetValue(uint64(x & y)) + case opcode.Or: + o.SetValue(uint64(x | y)) + case opcode.Xor: + o.SetValue(uint64(x ^ y)) + case opcode.RightShift: + o.SetValue(uint64(x) >> uint64(y)) + case opcode.LeftShift: + o.SetValue(uint64(x) << uint64(y)) + default: + e.err = errors.Errorf("invalid op %v in bit operation", o.Op) + return false + } + return true +} + +func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool { + a, err := coerceArithmetic(o.L.GetValue()) + if err != nil { + e.err = err + return false + } + + b, err := coerceArithmetic(o.R.GetValue()) + if err != nil { + e.err = err + return false + } + a, b = types.Coerce(a, b) + + if a == nil || b == nil { + // TODO: for <=>, if a and b are both nil, return true + o.SetValue(nil) + return true + } + + // TODO: support logic division DIV + var result interface{} + switch o.Op { + case opcode.Plus: + result, e.err = computePlus(a, b) + case opcode.Minus: + result, e.err = computeMinus(a, b) + case opcode.Mul: + result, e.err = computeMul(a, b) + case opcode.Div: + result, e.err = computeDiv(a, b) + case opcode.Mod: + result, e.err = computeMod(a, b) + case opcode.IntDiv: + result, e.err = computeIntDiv(a, b) + default: + e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op) + return false + } + o.SetValue(result) + return e.err == nil +} + +func computePlus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.AddInt64(x, y) + case uint64: + return types.AddInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.AddInteger(x, y) + case uint64: + return types.AddUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x + y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Add(y), nil + } + } + return types.InvOp2(a, b, opcode.Plus) +} + +func computeMinus(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.SubInt64(x, y) + case uint64: + return types.SubIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + return types.SubUintWithInt(x, y) + case uint64: + return types.SubUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x - y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Sub(y), nil + } + } + + return types.InvOp2(a, b, opcode.Minus) +} + +func computeMul(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + return types.MulInt64(x, y) + case uint64: + return types.MulInteger(y, x) + } + case uint64: + switch y := b.(type) { + case int64: + return types.MulInteger(x, y) + case uint64: + return types.MulUint64(x, y) + } + case float64: + switch y := b.(type) { + case float64: + return x * y, nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + return x.Mul(y), nil + } + } + + return types.InvOp2(a, b, opcode.Mul) +} + +func computeDiv(a, b interface{}) (interface{}, error) { + // MySQL support integer divison Div and division operator / + // we use opcode.Div for division operator and will use another for integer division later. + // for division operator, we will use float64 for calculation. + switch x := a.(type) { + case float64: + y, err := types.ToFloat64(b) + if err != nil { + return nil, err + } + + if y == 0 { + return nil, nil + } + + return x / y, nil + default: + // the scale of the result is the scale of the first operand plus + // the value of the div_precision_increment system variable (which is 4 by default) + // we will use 4 here + + xa, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + xb, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + if f, _ := xb.Float64(); f == 0 { + // division by zero return null + return nil, nil + } + + return xa.Div(xb), nil + } +} + +func computeMod(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return x % y, nil + case uint64: + if y == 0 { + return nil, nil + } else if x < 0 { + // first is int64, return int64. + return -int64(uint64(-x) % y), nil + } + return int64(uint64(x) % y), nil + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } else if y < 0 { + // first is uint64, return uint64. + return uint64(x % uint64(-y)), nil + } + return x % uint64(y), nil + case uint64: + if y == 0 { + return nil, nil + } + return x % y, nil + } + case float64: + switch y := b.(type) { + case float64: + if y == 0 { + return nil, nil + } + return math.Mod(x, y), nil + } + case mysql.Decimal: + switch y := b.(type) { + case mysql.Decimal: + xf, _ := x.Float64() + yf, _ := y.Float64() + if yf == 0 { + return nil, nil + } + return math.Mod(xf, yf), nil + } + } + + return types.InvOp2(a, b, opcode.Mod) +} + +func computeIntDiv(a, b interface{}) (interface{}, error) { + switch x := a.(type) { + case int64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivInt64(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return types.DivIntWithUint(x, y) + } + case uint64: + switch y := b.(type) { + case int64: + if y == 0 { + return nil, nil + } + return types.DivUintWithInt(x, y) + case uint64: + if y == 0 { + return nil, nil + } + return x / y, nil + } + } + + // if any is none integer, use decimal to calculate + x, err := types.ToDecimal(a) + if err != nil { + return nil, err + } + + y, err := types.ToDecimal(b) + if err != nil { + return nil, err + } + + if f, _ := y.Float64(); f == 0 { + return nil, nil + } + + return x.Div(y).IntPart(), nil +} + +func coerceArithmetic(a interface{}) (interface{}, error) { + switch x := a.(type) { + case string: + // MySQL will convert string to float for arithmetic operation + f, err := types.StrToFloat(x) + if err != nil { + return nil, err + } + return f, err + case mysql.Time: + // if time has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case mysql.Duration: + // if duration has no precision, return int64 + v := x.ToNumber() + if x.Fsp == 0 { + return v.IntPart(), nil + } + return v, nil + case []byte: + // []byte is the same as string, converted to float for arithmetic operator. + f, err := types.StrToFloat(string(x)) + if err != nil { + return nil, err + } + return f, err + case mysql.Hex: + return x.ToNumber(), nil + case mysql.Bit: + return x.ToNumber(), nil + case mysql.Enum: + return x.ToNumber(), nil + case mysql.Set: + return x.ToNumber(), nil + default: + return x, nil + } +} diff --git a/optimizer/infobinder.go b/optimizer/infobinder.go new file mode 100644 index 0000000000..0ad59bceb6 --- /dev/null +++ b/optimizer/infobinder.go @@ -0,0 +1,441 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import ( + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/model" +) + +// InfoBinder binds schema information for table name and column name and set result fields +// for ResetSetNode. +// We need to know which table a table name refers to, which column a column name refers to. +// +// In general, a reference can only refer to information that are available for it. +// So children elements are visited in the order that previous elements make information +// available for following elements. +// +// During visiting, information are collected and stored in binderContext. +// When we enter a sub query, a new binderContext is pushed to the contextStack, so sub query +// information can overwrite outer query information. When we look up for a column reference, +// we look up from top to bottom in the contextStack. +type InfoBinder struct { + Info infoschema.InfoSchema + DefaultSchema model.CIStr + Err error + + contextStack []*binderContext +} + +// binderContext stores information that table name and column name +// can be bind to. +type binderContext struct { + /* For Select Statement. */ + // table map to lookup and check table name conflict. + tableMap map[string]int + // tableSources collected in from clause. + tables []*ast.TableSource + // result fields collected in select field list. + fieldList []*ast.ResultField + // result fields collected in group by clause. + groupBy []*ast.ResultField + + // The join node stack is used by on condition to find out + // available tables to reference. On condition can only + // refer to tables involved in current join. + joinNodeStack []*ast.Join + + // When visiting TableRefs, tables in this context are not available + // because it is being collected. + inTableRefs bool + // When visiting on conditon only tables in current join node are available. + inOnCondition bool + // When visiting field list, fieldList in this context are not available. + inFieldList bool + // When visiting group by, groupBy fields are not available. + inGroupBy bool + // When visiting having, only fieldList and groupBy fields are available. + inHaving bool +} + +// currentContext gets the current binder context. +func (sb *InfoBinder) currentContext() *binderContext { + stackLen := len(sb.contextStack) + if stackLen == 0 { + return nil + } + return sb.contextStack[stackLen-1] +} + +// pushContext is called when we enter a statement. +func (sb *InfoBinder) pushContext() { + sb.contextStack = append(sb.contextStack, &binderContext{ + tableMap: map[string]int{}, + }) +} + +// popContext is called when we leave a statement. +func (sb *InfoBinder) popContext() { + sb.contextStack = sb.contextStack[:len(sb.contextStack)-1] +} + +// pushJoin is called when we enter a join node. +func (sb *InfoBinder) pushJoin(j *ast.Join) { + ctx := sb.currentContext() + ctx.joinNodeStack = append(ctx.joinNodeStack, j) +} + +// popJoin is called when we leave a join node. +func (sb *InfoBinder) popJoin() { + ctx := sb.currentContext() + ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1] +} + +// Enter implements ast.Visitor interface. +func (sb *InfoBinder) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { + switch v := inNode.(type) { + case *ast.SelectStmt: + sb.pushContext() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = true + case *ast.Join: + sb.pushJoin(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = true + case *ast.FieldList: + sb.currentContext().inFieldList = true + case *ast.GroupByClause: + sb.currentContext().inGroupBy = true + case *ast.HavingClause: + sb.currentContext().inHaving = true + case *ast.InsertStmt: + sb.pushContext() + case *ast.DeleteStmt: + sb.pushContext() + case *ast.UpdateStmt: + sb.pushContext() + } + return inNode, false +} + +// Leave implements ast.Visitor interface. +func (sb *InfoBinder) Leave(inNode ast.Node) (node ast.Node, ok bool) { + switch v := inNode.(type) { + case *ast.TableName: + sb.handleTableName(v) + case *ast.ColumnName: + sb.handleColumnName(v) + case *ast.TableSource: + sb.handleTableSource(v) + case *ast.OnCondition: + sb.currentContext().inOnCondition = false + case *ast.Join: + sb.handleJoin(v) + sb.popJoin() + case *ast.TableRefsClause: + sb.currentContext().inTableRefs = false + case *ast.FieldList: + sb.handleFieldList(v) + sb.currentContext().inFieldList = false + case *ast.GroupByClause: + sb.currentContext().inGroupBy = false + case *ast.HavingClause: + sb.currentContext().inHaving = false + case *ast.SelectStmt: + v.SetResultFields(sb.currentContext().fieldList) + sb.popContext() + case *ast.InsertStmt: + sb.popContext() + case *ast.DeleteStmt: + sb.popContext() + case *ast.UpdateStmt: + sb.popContext() + } + return inNode, sb.Err == nil +} + +// handleTableName looks up and bind the schema information for table name +// and set result fields for table name. +func (sb *InfoBinder) handleTableName(tn *ast.TableName) { + if tn.Schema.L == "" { + tn.Schema = sb.DefaultSchema + } + table, err := sb.Info.TableByName(tn.Schema, tn.Name) + if err != nil { + sb.Err = err + return + } + tn.TableInfo = table.Meta() + dbInfo, _ := sb.Info.SchemaByName(tn.Schema) + tn.DBInfo = dbInfo + + rfs := make([]*ast.ResultField, len(tn.TableInfo.Columns)) + for i, v := range tn.TableInfo.Columns { + rfs[i] = &ast.ResultField{ + Column: v, + Table: tn.TableInfo, + DBName: tn.Schema, + } + } + tn.SetResultFields(rfs) + return +} + +// handleTableSources checks name duplication +// and puts the table source in current binderContext. +func (sb *InfoBinder) handleTableSource(ts *ast.TableSource) { + for _, v := range ts.GetResultFields() { + v.TableAsName = ts.AsName + } + var name string + if ts.AsName.L != "" { + name = ts.AsName.L + } else { + tableName := ts.Source.(*ast.TableName) + name = sb.tableUniqueName(tableName.Schema, tableName.Name) + } + ctx := sb.currentContext() + if _, ok := ctx.tableMap[name]; ok { + sb.Err = errors.Errorf("duplicated table/alias name %s", name) + return + } + ctx.tableMap[name] = len(ctx.tables) + ctx.tables = append(ctx.tables, ts) + return +} + +// handleJoin sets result fields for join. +func (sb *InfoBinder) handleJoin(j *ast.Join) { + if j.Right == nil { + j.SetResultFields(j.Left.GetResultFields()) + return + } + leftLen := len(j.Left.GetResultFields()) + rightLen := len(j.Right.GetResultFields()) + rfs := make([]*ast.ResultField, leftLen+rightLen) + copy(rfs, j.Left.GetResultFields()) + copy(rfs[leftLen:], j.Right.GetResultFields()) + j.SetResultFields(rfs) +} + +// handleColumnName looks up and binds schema information to +// the column name. +func (sb *InfoBinder) handleColumnName(cn *ast.ColumnName) { + ctx := sb.currentContext() + if ctx.inOnCondition { + // In on condition, only tables within current join is available. + sb.bindColumnNameInOnCondition(cn) + return + } + + // Try to bind the column name form top to bottom in the context stack. + for i := len(sb.contextStack) - 1; i >= 0; i-- { + if sb.bindColumnNameInContext(sb.contextStack[i], cn) { + // Column is already bound or encountered an error. + return + } + } + sb.Err = errors.Errorf("Unknown column %s", cn.Name.L) +} + +// bindColumnNameInContext looks up and binds schema information for a column with the ctx. +func (sb *InfoBinder) bindColumnNameInContext(ctx *binderContext, cn *ast.ColumnName) (done bool) { + if cn.Table.L == "" { + // If qualified table name is not specified in column name, the column name may be ambiguous, + // We need to iterate over all tables and + } + + if ctx.inTableRefs { + // In TableRefsClause, column reference only in join on condition which is handled before. + return false + } + if ctx.inFieldList { + // only bind column using tables. + return sb.bindColumnInTableSources(cn, ctx.tables) + } + if ctx.inGroupBy { + // field list first, then tables. + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + return sb.bindColumnInTableSources(cn, ctx.tables) + } + // column name in other places can be looked up in the same order. + if sb.bindColumnInResultFields(cn, ctx.groupBy) { + return true + } + if sb.bindColumnInResultFields(cn, ctx.fieldList) { + return true + } + + // tables is not available for having clause. + if !ctx.inHaving { + return sb.bindColumnInTableSources(cn, ctx.tables) + } + return false +} + +// bindColumnNameInOnCondition looks up for column name in current join, and +// binds the schema information. +func (sb *InfoBinder) bindColumnNameInOnCondition(cn *ast.ColumnName) { + ctx := sb.currentContext() + join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1] + tableSources := appendTableSources(nil, join) + if !sb.bindColumnInTableSources(cn, tableSources) { + sb.Err = errors.Errorf("unkown column name %s", cn.Name.O) + } +} + +func (sb *InfoBinder) bindColumnInTableSources(cn *ast.ColumnName, tableSources []*ast.TableSource) (done bool) { + var matchedResultField *ast.ResultField + if cn.Table.L != "" { + var matchedTable ast.ResultSetNode + for _, ts := range tableSources { + if cn.Table.L == ts.AsName.L { + // different table name. + matchedTable = ts + break + } + if tn, ok := ts.Source.(*ast.TableName); ok { + if cn.Table.L == tn.Name.L { + matchedTable = ts + } + } + } + if matchedTable != nil { + resultFields := matchedTable.GetResultFields() + for _, rf := range resultFields { + if rf.ColumnAsName.L == cn.Name.L || rf.Column.Name.L == cn.Name.L { + // bind column. + matchedResultField = rf + break + } + } + } + } else { + for _, ts := range tableSources { + rfs := ts.GetResultFields() + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return true + } + matchedResultField = rf + } + } + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +func (sb *InfoBinder) bindColumnInResultFields(cn *ast.ColumnName, rfs []*ast.ResultField) bool { + var matchedResultField *ast.ResultField + for _, rf := range rfs { + matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L + matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L + if matchAsName || matchColumnName { + if matchedResultField != nil { + sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O) + return false + } + matchedResultField = rf + } + } + if matchedResultField != nil { + // bind column. + cn.ColumnInfo = matchedResultField.Column + cn.TableInfo = matchedResultField.Table + return true + } + return false +} + +// handleFieldList expands wild card field and set fieldList in current context. +func (sb *InfoBinder) handleFieldList(fieldList *ast.FieldList) { + var resultFields []*ast.ResultField + for _, v := range fieldList.Fields { + resultFields = append(resultFields, sb.createResultFields(v)...) + } + sb.currentContext().fieldList = resultFields +} + +// createResultFields creates result field list for a single select field. +func (sb *InfoBinder) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) { + ctx := sb.currentContext() + if field.WildCard != nil { + if len(ctx.tables) == 0 { + sb.Err = errors.Errorf("No table used.") + return + } + if field.WildCard.Table.L == "" { + for _, v := range ctx.tables { + rfs = append(rfs, v.GetResultFields()...) + } + } else { + name := sb.tableUniqueName(field.WildCard.Schema, field.WildCard.Table) + tableIdx, ok := ctx.tableMap[name] + if !ok { + sb.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O) + } + rfs = ctx.tables[tableIdx].GetResultFields() + } + return + } + // The column is visited before so it must has been bound already. + rf := &ast.ResultField{ColumnAsName: field.AsName} + switch v := field.Expr.(type) { + case *ast.ColumnNameExpr: + rf.Column = v.Name.ColumnInfo + rf.Table = v.Name.TableInfo + rf.DBName = v.Name.Schema + default: + if field.AsName.L == "" { + rf.ColumnAsName.L = field.Expr.Text() + rf.ColumnAsName.O = rf.ColumnAsName.L + } + } + rfs = append(rfs, rf) + return +} + +func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) { + switch v := resultSetNode.(type) { + case *ast.TableSource: + out = append(in, v) + case *ast.Join: + out = appendTableSources(in, v.Left) + if v.Right != nil { + out = appendTableSources(out, v.Right) + } + } + return +} + +func (sb *InfoBinder) tableUniqueName(schema, table model.CIStr) string { + if schema.L != "" && schema.L != sb.DefaultSchema.L { + return schema.L + "." + table.L + } + return table.L +} diff --git a/optimizer/infobinder_test.go b/optimizer/infobinder_test.go new file mode 100644 index 0000000000..988d6dfaa0 --- /dev/null +++ b/optimizer/infobinder_test.go @@ -0,0 +1,82 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer_test + +import ( + "testing" + + . "github.com/pingcap/check" + "github.com/pingcap/tidb" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ast/parser" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/optimizer" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/util/testkit" +) + +func TestT(t *testing.T) { + TestingT(t) +} + +var _ = Suite(&testInfoBinderSuite{}) + +type testInfoBinderSuite struct { +} + +type binderVerifier struct { + c *C +} + +func (bv *binderVerifier) Enter(node ast.Node) (ast.Node, bool) { + return node, false +} + +func (bv *binderVerifier) Leave(in ast.Node) (out ast.Node, ok bool) { + switch v := in.(type) { + case *ast.ColumnName: + bv.c.Assert(v.ColumnInfo, NotNil) + case *ast.TableName: + bv.c.Assert(v.TableInfo, NotNil) + } + return in, true +} + +func (ts *testInfoBinderSuite) TestInfoBinder(c *C) { + store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory) + c.Assert(err, IsNil) + defer store.Close() + testKit := testkit.NewTestKit(c, store) + testKit.MustExec("use test") + testKit.MustExec("create table t (c1 int, c2 int)") + domain := sessionctx.GetDomain(testKit.Se.(context.Context)) + + src := "SELECT c1 from t" + l := parser.NewLexer(src) + c.Assert(parser.YYParse(l), Equals, 0) + stmts := l.Stmts() + c.Assert(len(stmts), Equals, 1) + v := &optimizer.InfoBinder{ + Info: domain.InfoSchema(), + DefaultSchema: model.NewCIStr("test"), + } + selectStmt := stmts[0].(*ast.SelectStmt) + selectStmt.Accept(v) + + verifier := &binderVerifier{ + c: c, + } + selectStmt.Accept(verifier) +} diff --git a/optimizer/optimizer.go b/optimizer/typecomputer.go similarity index 60% rename from optimizer/optimizer.go rename to optimizer/typecomputer.go index 36ccd06c5f..f2cd35aff0 100644 --- a/optimizer/optimizer.go +++ b/optimizer/typecomputer.go @@ -13,20 +13,18 @@ package optimizer -import ( - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/stmt" -) +import "github.com/pingcap/tidb/ast" -// Compile compiles a ast.Node into a executable statement. -func Compile(node ast.Node) (stmt.Statement, error) { - switch v := node.(type) { - case *ast.SetStmt: - return compileSet(v) - } - return nil, nil +// typeComputer is an ast Visitor that +// computes result type for ast.ExprNode. +type typeComputer struct { + err error } -func compileSet(aset *ast.SetStmt) (stmt.Statement, error) { - return nil, nil +func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true } diff --git a/optimizer/validator.go b/optimizer/validator.go new file mode 100644 index 0000000000..f90a757ab8 --- /dev/null +++ b/optimizer/validator.go @@ -0,0 +1,30 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package optimizer + +import "github.com/pingcap/tidb/ast" + +// validator is an ast.Visitor that validates +// ast parsed from parser. +type validator struct { + err error +} + +func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + return in, false +} + +func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { + return in, true +} diff --git a/stmt/stmts/drop_test.go b/stmt/stmts/drop_test.go index 5d0a34ec0d..4a1de810c2 100644 --- a/stmt/stmts/drop_test.go +++ b/stmt/stmts/drop_test.go @@ -62,7 +62,7 @@ func (s *testStmtSuite) TestDropTable(c *C) { } func (s *testStmtSuite) TestDropIndex(c *C) { - testSQL := "drop index if exists drop_index;" + testSQL := "drop index if exists drop_index on t;" stmtList, err := tidb.Compile(s.ctx, testSQL) c.Assert(err, IsNil) diff --git a/stmt/stmts/show_test.go b/stmt/stmts/show_test.go index 5dc0c57828..e67813a46d 100644 --- a/stmt/stmts/show_test.go +++ b/stmt/stmts/show_test.go @@ -51,5 +51,5 @@ func (s *testStmtSuite) TestShow(c *C) { c.Assert(stmtList, HasLen, 1) testStmt, ok = stmtList[0].(*stmts.ShowStmt) c.Assert(ok, IsTrue) - c.Assert(testStmt.Pattern, NotNil) + c.Assert(testStmt.Pattern, NotNil, Commentf("S: %s", testStmt.Text)) } diff --git a/tidb.go b/tidb.go index 55581af782..27fcf5167a 100644 --- a/tidb.go +++ b/tidb.go @@ -26,14 +26,13 @@ import ( "github.com/juju/errors" "github.com/ngaut/log" - "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/ast/parser" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/field" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/optimizer" - "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/variable" @@ -132,15 +131,12 @@ func Compile(ctx context.Context, src string) ([]stmt.Statement, error) { rawStmt := l.Stmts() stmts := make([]stmt.Statement, len(rawStmt)) for i, v := range rawStmt { - if node, ok := v.(ast.Node); ok { - stm, err := optimizer.Compile(node) - if err != nil { - return nil, errors.Trace(err) - } - stmts[i] = stm - } else { - stmts[i] = v.(stmt.Statement) + compiler := &optimizer.Compiler{} + stm, err := compiler.Compile(v) + if err != nil { + return nil, errors.Trace(err) } + stmts[i] = stm } return stmts, nil } @@ -162,7 +158,12 @@ func CompilePrepare(ctx context.Context, src string) (stmt.Statement, []*express return nil, nil, nil } sm := sms[0] - return sm.(stmt.Statement), l.ParamList, nil + compiler := &optimizer.Compiler{} + statement, err := compiler.Compile(sm) + if err != nil { + return nil, nil, errors.Trace(err) + } + return statement, compiler.ParamMarkers(), nil } func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) { diff --git a/tidb_test.go b/tidb_test.go index 0cb22ed701..8249ba21aa 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -640,6 +640,7 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) { // conflict mustExecSQL(c, se1, "begin") rs, err := exec(c, se1, "select * from t where c1=11 for update") + c.Assert(err, IsNil) _, err = rs.Rows(-1, 0) mustExecSQL(c, se2, "begin")