Merge branch 'master' into coocood/use-terror

This commit is contained in:
Ewan Chou
2015-11-05 19:21:00 +08:00
33 changed files with 5525 additions and 7288 deletions

View File

@ -13,7 +13,7 @@ TARGET = ""
.PHONY: godep deps all build install parser clean todo test gotest interpreter server
all: godep parser ast-parser build test check
all: godep parser build test check
godep:
go get github.com/tools/godep
@ -31,7 +31,7 @@ parser:
goyacc -o /dev/null -xegen $$a parser/parser.y; \
goyacc -o parser/parser.go -xe $$a parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \
rm -f $$a; \
rm -f y.output
rm -f y.output
@if [ $(ARCH) = $(LINUX) ]; \
then \
@ -44,24 +44,6 @@ parser:
golex -o parser/scanner.go parser/scanner.l
ast-parser:
a=`mktemp temp.XXXXXX`; \
goyacc -o /dev/null -xegen $$a ast/parser/parser.y; \
goyacc -o ast/parser/parser.go -xe $$a ast/parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \
rm -f $$a; \
rm -f y.output
@if [ $(ARCH) = $(LINUX) ]; \
then \
sed -i -e 's|//line.*||' -e 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \
elif [ $(ARCH) = $(MAC) ]; \
then \
/usr/bin/sed -i "" 's|//line.*||' ast/parser/parser.go; \
/usr/bin/sed -i "" 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \
fi
golex -o ast/parser/scanner.go ast/parser/scanner.l
check:
go get github.com/golang/lint/golint

View File

@ -26,6 +26,11 @@ type Node interface {
// Accept accepts Visitor to visit itself.
// The returned node should replace original node.
// ok returns false to stop visiting.
//
// Implementation of this method should first call visitor.Enter,
// assign the returned node to its method receiver, if skipChildren returns true,
// children should be skipped. Otherwise, call its children in particular order that
// later elements depends on former elements. Finally, return visitor.Leave.
Accept(v Visitor) (node Node, ok bool)
// Text returns the original text of the element.
Text() string
@ -44,6 +49,10 @@ type ExprNode interface {
SetType(tp *types.FieldType)
// GetType gets the evaluation type of the expression.
GetType() *types.FieldType
// SetValue sets value to the expression.
SetValue(val interface{})
// GetValue gets value of the expression.
GetValue() interface{}
}
// FuncNode represents function call expression node.
@ -93,11 +102,12 @@ type ResultSetNode interface {
// Visitor visits a Node.
type Visitor interface {
// VisitEnter is called before children nodes is visited.
// The returned node must be the same type as the input node n.
// skipChildren returns true means children nodes should be skipped,
// this is useful when work is done in Enter and there is no need to visit children.
// ok returns false to stop visiting.
Enter(n Node) (skipChildren bool, ok bool)
Enter(n Node) (node Node, skipChildren bool)
// VisitLeave is called after children nodes has been visited.
// The returned node must be the same type as the input node n.
// ok returns false to stop visiting.
Leave(n Node) (node Node, ok bool)
}

View File

@ -62,7 +62,7 @@ func (dn *dmlNode) dmlStatement() {}
// Expression implementations should embed it in.
type exprNode struct {
node
tp *types.FieldType
types.DataItem
}
// IsStatic implements Expression interface.
@ -72,12 +72,22 @@ func (en *exprNode) IsStatic() bool {
// SetType implements Expression interface.
func (en *exprNode) SetType(tp *types.FieldType) {
en.tp = tp
en.Type = tp
}
// GetType implements Expression interface.
func (en *exprNode) GetType() *types.FieldType {
return en.tp
return en.Type
}
// SetValue implements Expression interface.
func (en *exprNode) SetValue(val interface{}) {
en.Data = val
}
// GetValue implements Expression interface.
func (en *exprNode) GetValue() interface{} {
return en.Data
}
type funcNode struct {

173
ast/cloner.go Normal file
View File

@ -0,0 +1,173 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package ast
import "fmt"
// Cloner is an ast visitor that clones a node.
type Cloner struct {
}
// Enter implements Visitor Enter interface.
func (c *Cloner) Enter(node Node) (Node, bool) {
return copyStruct(node), false
}
// Leave implements Visitor Leave interface.
func (c *Cloner) Leave(in Node) (out Node, ok bool) {
return in, true
}
// copyStruct copies a node's struct value, if the struct has slice member,
// make a new slice and copy old slice value to new slice.
func copyStruct(in Node) (out Node) {
switch v := in.(type) {
case *ValueExpr:
nv := *v
out = &nv
case *BetweenExpr:
nv := *v
out = &nv
case *BinaryOperationExpr:
nv := *v
out = &nv
case *WhenClause:
nv := *v
out = &nv
case *CaseExpr:
nv := *v
nv.WhenClauses = make([]*WhenClause, len(v.WhenClauses))
copy(nv.WhenClauses, v.WhenClauses)
out = &nv
case *SubqueryExpr:
nv := *v
out = &nv
case *CompareSubqueryExpr:
nv := *v
out = &nv
case *ColumnName:
nv := *v
out = &nv
case *ColumnNameExpr:
nv := *v
out = &nv
case *DefaultExpr:
nv := *v
out = &nv
case *IdentifierExpr:
nv := *v
out = &nv
case *ExistsSubqueryExpr:
nv := *v
out = &nv
case *PatternInExpr:
nv := *v
nv.List = make([]ExprNode, len(v.List))
copy(nv.List, v.List)
out = &nv
case *IsNullExpr:
nv := *v
out = &nv
case *IsTruthExpr:
nv := *v
out = &nv
case *PatternLikeExpr:
nv := *v
out = &nv
case *ParamMarkerExpr:
nv := *v
out = &nv
case *ParenthesesExpr:
nv := *v
out = &nv
case *PositionExpr:
nv := *v
out = &nv
case *PatternRegexpExpr:
nv := *v
out = &nv
case *RowExpr:
nv := *v
nv.Values = make([]ExprNode, len(v.Values))
copy(nv.Values, v.Values)
out = &nv
case *UnaryOperationExpr:
nv := *v
out = &nv
case *ValuesExpr:
nv := *v
out = &nv
case *VariableExpr:
nv := *v
out = &nv
case *Join:
nv := *v
out = &nv
case *TableName:
nv := *v
out = &nv
case *TableSource:
nv := *v
out = &nv
case *OnCondition:
nv := *v
out = &nv
case *WildCardField:
nv := *v
out = &nv
case *SelectField:
nv := *v
out = &nv
case *FieldList:
nv := *v
nv.Fields = make([]*SelectField, len(v.Fields))
copy(nv.Fields, v.Fields)
out = &nv
case *TableRefsClause:
nv := *v
out = &nv
case *ByItem:
nv := *v
out = &nv
case *GroupByClause:
nv := *v
nv.Items = make([]*ByItem, len(v.Items))
copy(nv.Items, v.Items)
out = &nv
case *HavingClause:
nv := *v
out = &nv
case *OrderByClause:
nv := *v
nv.Items = make([]*ByItem, len(v.Items))
copy(nv.Items, v.Items)
out = &nv
case *SelectStmt:
nv := *v
out = &nv
case *UnionClause:
nv := *v
out = &nv
case *UnionStmt:
nv := *v
nv.Selects = make([]*SelectStmt, len(v.Selects))
copy(nv.Selects, v.Selects)
out = &nv
default:
// We currently only handle expression and select statement.
// Will add more when we need to.
panic("unknown ast Node type " + fmt.Sprintf("%T", v))
}
return
}

40
ast/cloner_test.go Normal file
View File

@ -0,0 +1,40 @@
package ast
import (
"testing"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/parser/opcode"
)
func TestT(t *testing.T) {
TestingT(t)
}
var _ = Suite(&testClonerSuite{})
type testClonerSuite struct {
}
func (ts *testClonerSuite) TestCloner(c *C) {
cloner := &Cloner{}
a := &UnaryOperationExpr{
Op: opcode.Not,
V: &UnaryOperationExpr{V: NewValueExpr(true)},
}
b, ok := a.Accept(cloner)
c.Assert(ok, IsTrue)
a1 := a.V
b1 := b.(*UnaryOperationExpr).V
c.Assert(a1, Not(Equals), b1)
a2 := a1.(*UnaryOperationExpr).V
b2 := b1.(*UnaryOperationExpr).V
c.Assert(a2, Not(Equals), b2)
a3 := a2.(*ValueExpr)
b3 := b2.(*ValueExpr)
c.Assert(a3, Not(Equals), b3)
c.Assert(a3.GetValue(), Equals, true)
c.Assert(b3.GetValue(), Equals, true)
}

View File

@ -70,11 +70,13 @@ type CreateDatabaseStmt struct {
}
// Accept implements Node Accept interface.
func (cd *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(cd); skipChildren {
return cd, ok
func (n *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(cd)
n = newNod.(*CreateDatabaseStmt)
return v.Leave(n)
}
// DropDatabaseStmt is a statement to drop a database and all tables in the database.
@ -87,11 +89,13 @@ type DropDatabaseStmt struct {
}
// Accept implements Node Accept interface.
func (dd *DropDatabaseStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(dd); skipChildren {
return dd, ok
func (n *DropDatabaseStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(dd)
n = newNod.(*DropDatabaseStmt)
return v.Leave(n)
}
// IndexColName is used for parsing index column name from SQL.
@ -103,16 +107,18 @@ type IndexColName struct {
}
// Accept implements Node Accept interface.
func (ic *IndexColName) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ic); skipChildren {
return ic, ok
func (n *IndexColName) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ic.Column.Accept(v)
n = newNod.(*IndexColName)
node, ok := n.Column.Accept(v)
if !ok {
return ic, false
return n, false
}
ic.Column = node.(*ColumnName)
return v.Leave(ic)
n.Column = node.(*ColumnName)
return v.Leave(n)
}
// ReferenceDef is used for parsing foreign key reference option from SQL.
@ -125,23 +131,25 @@ type ReferenceDef struct {
}
// Accept implements Node Accept interface.
func (rd *ReferenceDef) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(rd); skipChildren {
return rd, ok
func (n *ReferenceDef) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := rd.Table.Accept(v)
n = newNod.(*ReferenceDef)
node, ok := n.Table.Accept(v)
if !ok {
return rd, false
return n, false
}
rd.Table = node.(*TableName)
for i, val := range rd.IndexColNames {
n.Table = node.(*TableName)
for i, val := range n.IndexColNames {
node, ok = val.Accept(v)
if !ok {
return rd, false
return n, false
}
rd.IndexColNames[i] = node.(*IndexColName)
n.IndexColNames[i] = node.(*IndexColName)
}
return v.Leave(rd)
return v.Leave(n)
}
// ColumnOptionType is the type for ColumnOption.
@ -175,18 +183,20 @@ type ColumnOption struct {
}
// Accept implements Node Accept interface.
func (co *ColumnOption) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(co); skipChildren {
return co, ok
func (n *ColumnOption) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if co.Expr != nil {
node, ok := co.Expr.Accept(v)
n = newNod.(*ColumnOption)
if n.Expr != nil {
node, ok := n.Expr.Accept(v)
if !ok {
return co, false
return n, false
}
co.Expr = node.(ExprNode)
n.Expr = node.(ExprNode)
}
return v.Leave(co)
return v.Leave(n)
}
// ConstraintType is the type for Constraint.
@ -220,25 +230,27 @@ type Constraint struct {
}
// Accept implements Node Accept interface.
func (tc *Constraint) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(tc); skipChildren {
return tc, ok
func (n *Constraint) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range tc.Keys {
n = newNod.(*Constraint)
for i, val := range n.Keys {
node, ok := val.Accept(v)
if !ok {
return tc, false
return n, false
}
tc.Keys[i] = node.(*IndexColName)
n.Keys[i] = node.(*IndexColName)
}
if tc.Refer != nil {
node, ok := tc.Refer.Accept(v)
if n.Refer != nil {
node, ok := n.Refer.Accept(v)
if !ok {
return tc, false
return n, false
}
tc.Refer = node.(*ReferenceDef)
n.Refer = node.(*ReferenceDef)
}
return v.Leave(tc)
return v.Leave(n)
}
// ColumnDef is used for parsing column definition from SQL.
@ -251,23 +263,25 @@ type ColumnDef struct {
}
// Accept implements Node Accept interface.
func (cd *ColumnDef) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(cd); skipChildren {
return cd, ok
func (n *ColumnDef) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := cd.Name.Accept(v)
n = newNod.(*ColumnDef)
node, ok := n.Name.Accept(v)
if !ok {
return cd, false
return n, false
}
cd.Name = node.(*ColumnName)
for i, val := range cd.Options {
n.Name = node.(*ColumnName)
for i, val := range n.Options {
node, ok := val.Accept(v)
if !ok {
return cd, false
return n, false
}
cd.Options[i] = node.(*ColumnOption)
n.Options[i] = node.(*ColumnOption)
}
return v.Leave(cd)
return v.Leave(n)
}
// CreateTableStmt is a statement to create a table.
@ -283,30 +297,32 @@ type CreateTableStmt struct {
}
// Accept implements Node Accept interface.
func (ct *CreateTableStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ct); skipChildren {
return ct, ok
func (n *CreateTableStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ct.Table.Accept(v)
n = newNod.(*CreateTableStmt)
node, ok := n.Table.Accept(v)
if !ok {
return ct, false
return n, false
}
ct.Table = node.(*TableName)
for i, val := range ct.Cols {
n.Table = node.(*TableName)
for i, val := range n.Cols {
node, ok = val.Accept(v)
if !ok {
return ct, false
return n, false
}
ct.Cols[i] = node.(*ColumnDef)
n.Cols[i] = node.(*ColumnDef)
}
for i, val := range ct.Constraints {
for i, val := range n.Constraints {
node, ok = val.Accept(v)
if !ok {
return ct, false
return n, false
}
ct.Constraints[i] = node.(*Constraint)
n.Constraints[i] = node.(*Constraint)
}
return v.Leave(ct)
return v.Leave(n)
}
// DropTableStmt is a statement to drop one or more tables.
@ -319,18 +335,20 @@ type DropTableStmt struct {
}
// Accept implements Node Accept interface.
func (dt *DropTableStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(dt); skipChildren {
return dt, ok
func (n *DropTableStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range dt.Tables {
n = newNod.(*DropTableStmt)
for i, val := range n.Tables {
node, ok := val.Accept(v)
if !ok {
return dt, false
return n, false
}
dt.Tables[i] = node.(*TableName)
n.Tables[i] = node.(*TableName)
}
return v.Leave(dt)
return v.Leave(n)
}
// CreateIndexStmt is a statement to create an index.
@ -345,23 +363,25 @@ type CreateIndexStmt struct {
}
// Accept implements Node Accept interface.
func (ci *CreateIndexStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ci); skipChildren {
return ci, ok
func (n *CreateIndexStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ci.Table.Accept(v)
n = newNod.(*CreateIndexStmt)
node, ok := n.Table.Accept(v)
if !ok {
return ci, false
return n, false
}
ci.Table = node.(*TableName)
for i, val := range ci.IndexColNames {
n.Table = node.(*TableName)
for i, val := range n.IndexColNames {
node, ok = val.Accept(v)
if !ok {
return ci, false
return n, false
}
ci.IndexColNames[i] = node.(*IndexColName)
n.IndexColNames[i] = node.(*IndexColName)
}
return v.Leave(ci)
return v.Leave(n)
}
// DropIndexStmt is a statement to drop the index.
@ -375,16 +395,18 @@ type DropIndexStmt struct {
}
// Accept implements Node Accept interface.
func (di *DropIndexStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(di); skipChildren {
return di, ok
func (n *DropIndexStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := di.Table.Accept(v)
n = newNod.(*DropIndexStmt)
node, ok := n.Table.Accept(v)
if !ok {
return di, false
return n, false
}
di.Table = node.(*TableName)
return v.Leave(di)
n.Table = node.(*TableName)
return v.Leave(n)
}
// TableOptionType is the type for TableOption
@ -435,18 +457,20 @@ type ColumnPosition struct {
}
// Accept implements Node Accept interface.
func (cp *ColumnPosition) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(cp); skipChildren {
return cp, ok
func (n *ColumnPosition) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if cp.RelativeColumn != nil {
node, ok := cp.RelativeColumn.Accept(v)
n = newNod.(*ColumnPosition)
if n.RelativeColumn != nil {
node, ok := n.RelativeColumn.Accept(v)
if !ok {
return cp, false
return n, false
}
cp.RelativeColumn = node.(*ColumnName)
n.RelativeColumn = node.(*ColumnName)
}
return v.Leave(cp)
return v.Leave(n)
}
// AlterTableType is the type for AlterTableSpec.
@ -474,44 +498,46 @@ type AlterTableSpec struct {
Constraint *Constraint
Options []*TableOption
Column *ColumnDef
ColumnName *ColumnName
DropColumn *ColumnName
Position *ColumnPosition
}
// Accept implements Node Accept interface.
func (as *AlterTableSpec) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(as); skipChildren {
return as, ok
func (n *AlterTableSpec) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if as.Constraint != nil {
node, ok := as.Constraint.Accept(v)
n = newNod.(*AlterTableSpec)
if n.Constraint != nil {
node, ok := n.Constraint.Accept(v)
if !ok {
return as, false
return n, false
}
as.Constraint = node.(*Constraint)
n.Constraint = node.(*Constraint)
}
if as.Column != nil {
node, ok := as.Column.Accept(v)
if n.Column != nil {
node, ok := n.Column.Accept(v)
if !ok {
return as, false
return n, false
}
as.Column = node.(*ColumnDef)
n.Column = node.(*ColumnDef)
}
if as.ColumnName != nil {
node, ok := as.ColumnName.Accept(v)
if n.DropColumn != nil {
node, ok := n.DropColumn.Accept(v)
if !ok {
return as, false
return n, false
}
as.ColumnName = node.(*ColumnName)
n.DropColumn = node.(*ColumnName)
}
if as.Position != nil {
node, ok := as.Position.Accept(v)
if n.Position != nil {
node, ok := n.Position.Accept(v)
if !ok {
return as, false
return n, false
}
as.Position = node.(*ColumnPosition)
n.Position = node.(*ColumnPosition)
}
return v.Leave(as)
return v.Leave(n)
}
// AlterTableStmt is a statement to change the structure of a table.
@ -524,23 +550,25 @@ type AlterTableStmt struct {
}
// Accept implements Node Accept interface.
func (at *AlterTableStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(at); skipChildren {
return at, ok
func (n *AlterTableStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := at.Table.Accept(v)
n = newNod.(*AlterTableStmt)
node, ok := n.Table.Accept(v)
if !ok {
return at, false
return n, false
}
at.Table = node.(*TableName)
for i, val := range at.Specs {
n.Table = node.(*TableName)
for i, val := range n.Specs {
node, ok = val.Accept(v)
if !ok {
return at, false
return n, false
}
at.Specs[i] = node.(*AlterTableSpec)
n.Specs[i] = node.(*AlterTableSpec)
}
return v.Leave(at)
return v.Leave(n)
}
// TruncateTableStmt is a statement to empty a table completely.
@ -552,14 +580,16 @@ type TruncateTableStmt struct {
}
// Accept implements Node Accept interface.
func (ts *TruncateTableStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ts); skipChildren {
return ts, ok
func (n *TruncateTableStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ts.Table.Accept(v)
n = newNod.(*TruncateTableStmt)
node, ok := n.Table.Accept(v)
if !ok {
return ts, false
return n, false
}
ts.Table = node.(*TableName)
return v.Leave(ts)
n.Table = node.(*TableName)
return v.Leave(n)
}

View File

@ -22,6 +22,7 @@ var (
_ DMLNode = &DeleteStmt{}
_ DMLNode = &UpdateStmt{}
_ DMLNode = &SelectStmt{}
_ DMLNode = &UnionStmt{}
_ Node = &Join{}
_ Node = &TableName{}
_ Node = &TableSource{}
@ -55,34 +56,36 @@ type Join struct {
// Tp represents join type.
Tp JoinType
// On represents join on condition.
On ExprNode
On *OnCondition
}
// Accept implements Node Accept interface.
func (j *Join) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(j); skipChildren {
return j, ok
func (n *Join) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := j.Left.Accept(v)
n = newNod.(*Join)
node, ok := n.Left.Accept(v)
if !ok {
return j, false
return n, false
}
j.Left = node.(ResultSetNode)
if j.Right != nil {
node, ok = j.Right.Accept(v)
n.Left = node.(ResultSetNode)
if n.Right != nil {
node, ok = n.Right.Accept(v)
if !ok {
return j, false
return n, false
}
j.Right = node.(ResultSetNode)
n.Right = node.(ResultSetNode)
}
if j.On != nil {
node, ok = j.On.Accept(v)
if n.On != nil {
node, ok = n.On.Accept(v)
if !ok {
return j, false
return n, false
}
j.On = node.(ExprNode)
n.On = node.(*OnCondition)
}
return v.Leave(j)
return v.Leave(n)
}
// TableName represents a table name.
@ -98,11 +101,13 @@ type TableName struct {
}
// Accept implements Node Accept interface.
func (tr *TableName) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(tr); skipChildren {
return tr, ok
func (n *TableName) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(tr)
n = newNod.(*TableName)
return v.Leave(n)
}
// TableSource represents table source with a name.
@ -110,7 +115,7 @@ type TableSource struct {
node
// Source is the source of the data, can be a TableName,
// a SubQuery, or a JoinNode.
// a SelectStmt, a UnionStmt, or a JoinNode.
Source ResultSetNode
// AsName is the as name of the table source.
@ -118,47 +123,50 @@ type TableSource struct {
}
// Accept implements Node Accept interface.
func (ts *TableSource) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ts); skipChildren {
return ts, ok
func (n *TableSource) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ts.Source.Accept(v)
n = newNod.(*TableSource)
node, ok := n.Source.Accept(v)
if !ok {
return ts, false
return n, false
}
ts.Source = node.(ResultSetNode)
return v.Leave(ts)
n.Source = node.(ResultSetNode)
return v.Leave(n)
}
// SetResultFields implements ResultSet interface.
func (ts *TableSource) SetResultFields(rfs []*ResultField) {
ts.Source.SetResultFields(rfs)
}
// GetResultFields implements ResultSet interface.
func (ts *TableSource) GetResultFields() []*ResultField {
return ts.Source.GetResultFields()
}
// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause.
type UnionClause struct {
// OnCondition represetns JOIN on condition.
type OnCondition struct {
node
Distinct bool
Select *SelectStmt
Expr ExprNode
}
// Accept implements Node Accept interface.
func (uc *UnionClause) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(uc); skipChildren {
return uc, ok
func (n *OnCondition) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := uc.Select.Accept(v)
n = newNod.(*OnCondition)
node, ok := n.Expr.Accept(v)
if !ok {
return uc, false
return n, false
}
uc.Select = node.(*SelectStmt)
return v.Leave(uc)
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// SetResultFields implements ResultSet interface.
func (n *TableSource) SetResultFields(rfs []*ResultField) {
n.Source.SetResultFields(rfs)
}
// GetResultFields implements ResultSet interface.
func (n *TableSource) GetResultFields() []*ResultField {
return n.Source.GetResultFields()
}
// SelectLockType is the lock type for SelectStmt.
@ -175,22 +183,18 @@ const (
type WildCardField struct {
node
Table *TableName
Table model.CIStr
Schema model.CIStr
}
// Accept implements Node Accept interface.
func (wf *WildCardField) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(wf); skipChildren {
return wf, ok
func (n *WildCardField) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if wf.Table != nil {
node, ok := wf.Table.Accept(v)
if !ok {
return wf, false
}
wf.Table = node.(*TableName)
}
return v.Leave(wf)
n = newNod.(*WildCardField)
return v.Leave(n)
}
// SelectField represents fields in select statement.
@ -199,6 +203,8 @@ func (wf *WildCardField) Accept(v Visitor) (Node, bool) {
type SelectField struct {
node
// Offset is used to get original text.
Offset int
// If WildCard is not nil, Expr will be nil.
WildCard *WildCardField
// If Expr is not nil, WildCard will be nil.
@ -208,22 +214,70 @@ type SelectField struct {
}
// Accept implements Node Accept interface.
func (sf *SelectField) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(sf); skipChildren {
return sf, ok
func (n *SelectField) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if sf.Expr != nil {
node, ok := sf.Expr.Accept(v)
n = newNod.(*SelectField)
if n.Expr != nil {
node, ok := n.Expr.Accept(v)
if !ok {
return sf, false
return n, false
}
sf.Expr = node.(ExprNode)
n.Expr = node.(ExprNode)
}
return v.Leave(sf)
return v.Leave(n)
}
// OrderByItem represents a single order by item.
type OrderByItem struct {
// FieldList represents field list in select statement.
type FieldList struct {
node
Fields []*SelectField
}
// Accept implements Node Accept interface.
func (n *FieldList) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*FieldList)
for i, val := range n.Fields {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Fields[i] = node.(*SelectField)
}
return v.Leave(n)
}
// TableRefsClause represents table references clause in dml statement.
type TableRefsClause struct {
node
TableRefs *Join
}
// Accept implements Node Accept interface.
func (n *TableRefsClause) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*TableRefsClause)
node, ok := n.TableRefs.Accept(v)
if !ok {
return n, false
}
n.TableRefs = node.(*Join)
return v.Leave(n)
}
// ByItem represents an item in order by or group by.
type ByItem struct {
node
Expr ExprNode
@ -231,131 +285,244 @@ type OrderByItem struct {
}
// Accept implements Node Accept interface.
func (ob *OrderByItem) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ob); skipChildren {
return ob, ok
func (n *ByItem) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ob.Expr.Accept(v)
n = newNod.(*ByItem)
node, ok := n.Expr.Accept(v)
if !ok {
return ob, false
return n, false
}
ob.Expr = node.(ExprNode)
return v.Leave(ob)
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// GroupByClause represents group by clause.
type GroupByClause struct {
node
Items []*ByItem
}
// Accept implements Node Accept interface.
func (n *GroupByClause) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*GroupByClause)
for i, val := range n.Items {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Items[i] = node.(*ByItem)
}
return v.Leave(n)
}
// HavingClause represents having clause.
type HavingClause struct {
node
Expr ExprNode
}
// Accept implements Node Accept interface.
func (n *HavingClause) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*HavingClause)
node, ok := n.Expr.Accept(v)
if !ok {
return n, false
}
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// OrderByClause represents order by clause.
type OrderByClause struct {
node
Items []*ByItem
ForUnion bool
}
// Accept implements Node Accept interface.
func (n *OrderByClause) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*OrderByClause)
for i, val := range n.Items {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Items[i] = node.(*ByItem)
}
return v.Leave(n)
}
// SelectStmt represents the select query node.
// See: https://dev.mysql.com/doc/refman/5.7/en/select.html
type SelectStmt struct {
dmlNode
resultSetNode
// Distinct represents if the select has distinct option.
Distinct bool
// Fields is the select expression list.
Fields []*SelectField
// From is the from clause of the query.
From *Join
From *TableRefsClause
// Where is the where clause in select statement.
Where ExprNode
// Fields is the select expression list.
Fields *FieldList
// GroupBy is the group by expression list.
GroupBy []ExprNode
GroupBy *GroupByClause
// Having is the having condition.
Having ExprNode
// OrderBy is the odering expression list.
OrderBy []*OrderByItem
Having *HavingClause
// OrderBy is the ordering expression list.
OrderBy *OrderByClause
// Limit is the limit clause.
Limit *Limit
// Lock is the lock type
LockTp SelectLockType
// Union clauses.
Unions []*UnionClause
// Order by for union select.
UnionOrderBy []*OrderByItem
// Limit for union select.
UnionLimit *Limit
}
// Accept implements Node Accept interface.
func (sn *SelectStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(sn); skipChildren {
return sn, ok
func (n *SelectStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range sn.Fields {
node, ok := val.Accept(v)
n = newNod.(*SelectStmt)
if n.From != nil {
node, ok := n.From.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.Fields[i] = node.(*SelectField)
}
if sn.From != nil {
node, ok := sn.From.Accept(v)
if !ok {
return sn, false
}
sn.From = node.(*Join)
n.From = node.(*TableRefsClause)
}
if sn.Where != nil {
node, ok := sn.Where.Accept(v)
if n.Where != nil {
node, ok := n.Where.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.Where = node.(ExprNode)
n.Where = node.(ExprNode)
}
for i, val := range sn.GroupBy {
node, ok := val.Accept(v)
if n.Fields != nil {
node, ok := n.Fields.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.GroupBy[i] = node.(ExprNode)
}
if sn.Having != nil {
node, ok := sn.Having.Accept(v)
if !ok {
return sn, false
}
sn.Having = node.(ExprNode)
n.Fields = node.(*FieldList)
}
for i, val := range sn.OrderBy {
node, ok := val.Accept(v)
if n.GroupBy != nil {
node, ok := n.GroupBy.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.OrderBy[i] = node.(*OrderByItem)
n.GroupBy = node.(*GroupByClause)
}
if sn.Limit != nil {
node, ok := sn.Limit.Accept(v)
if n.Having != nil {
node, ok := n.Having.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.Limit = node.(*Limit)
n.Having = node.(*HavingClause)
}
for i, val := range sn.Unions {
if n.OrderBy != nil {
node, ok := n.OrderBy.Accept(v)
if !ok {
return n, false
}
n.OrderBy = node.(*OrderByClause)
}
if n.Limit != nil {
node, ok := n.Limit.Accept(v)
if !ok {
return n, false
}
n.Limit = node.(*Limit)
}
return v.Leave(n)
}
// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause.
type UnionClause struct {
node
Distinct bool
Select *SelectStmt
}
// Accept implements Node Accept interface.
func (n *UnionClause) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*UnionClause)
node, ok := n.Select.Accept(v)
if !ok {
return n, false
}
n.Select = node.(*SelectStmt)
return v.Leave(n)
}
// UnionStmt represents "union statement"
// See: https://dev.mysql.com/doc/refman/5.7/en/union.html
type UnionStmt struct {
dmlNode
resultSetNode
Distinct bool
Selects []*SelectStmt
OrderBy *OrderByClause
Limit *Limit
}
// Accept implements Node Accept interface.
func (n *UnionStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*UnionStmt)
for i, val := range n.Selects {
node, ok := val.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.Unions[i] = node.(*UnionClause)
n.Selects[i] = node.(*SelectStmt)
}
for i, val := range sn.UnionOrderBy {
node, ok := val.Accept(v)
if n.OrderBy != nil {
node, ok := n.OrderBy.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.UnionOrderBy[i] = node.(*OrderByItem)
n.OrderBy = node.(*OrderByClause)
}
if sn.UnionLimit != nil {
node, ok := sn.UnionLimit.Accept(v)
if n.Limit != nil {
node, ok := n.Limit.Accept(v)
if !ok {
return sn, false
return n, false
}
sn.UnionLimit = node.(*Limit)
n.Limit = node.(*Limit)
}
return v.Leave(sn)
return v.Leave(n)
}
// Assignment is the expression for assignment, like a = 1.
@ -368,21 +535,23 @@ type Assignment struct {
}
// Accept implements Node Accept interface.
func (as *Assignment) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(as); skipChildren {
return as, ok
func (n *Assignment) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := as.Column.Accept(v)
n = newNod.(*Assignment)
node, ok := n.Column.Accept(v)
if !ok {
return as, false
return n, false
}
as.Column = node.(*ColumnName)
node, ok = as.Expr.Accept(v)
n.Column = node.(*ColumnName)
node, ok = n.Expr.Accept(v)
if !ok {
return as, false
return n, false
}
as.Expr = node.(ExprNode)
return v.Leave(as)
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// Priority const values.
@ -399,58 +568,68 @@ const (
type InsertStmt struct {
dmlNode
Replace bool
Table *TableRefsClause
Columns []*ColumnName
Lists [][]ExprNode
Table *TableName
Setlist []*Assignment
Priority int
OnDuplicate []*Assignment
Select *SelectStmt
Select ResultSetNode
}
// Accept implements Node Accept interface.
func (in *InsertStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(in); skipChildren {
return in, ok
func (n *InsertStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range in.Columns {
n = newNod.(*InsertStmt)
if n.Select != nil {
node, ok := n.Select.Accept(v)
if !ok {
return n, false
}
n.Select = node.(ResultSetNode)
}
node, ok := n.Table.Accept(v)
if !ok {
return n, false
}
n.Table = node.(*TableRefsClause)
for i, val := range n.Columns {
node, ok := val.Accept(v)
if !ok {
return in, false
return n, false
}
in.Columns[i] = node.(*ColumnName)
n.Columns[i] = node.(*ColumnName)
}
for i, list := range in.Lists {
for i, list := range n.Lists {
for j, val := range list {
node, ok := val.Accept(v)
if !ok {
return in, false
return n, false
}
in.Lists[i][j] = node.(ExprNode)
n.Lists[i][j] = node.(ExprNode)
}
}
for i, val := range in.Setlist {
for i, val := range n.Setlist {
node, ok := val.Accept(v)
if !ok {
return in, false
return n, false
}
in.Setlist[i] = node.(*Assignment)
n.Setlist[i] = node.(*Assignment)
}
for i, val := range in.OnDuplicate {
for i, val := range n.OnDuplicate {
node, ok := val.Accept(v)
if !ok {
return in, false
return n, false
}
in.OnDuplicate[i] = node.(*Assignment)
n.OnDuplicate[i] = node.(*Assignment)
}
if in.Select != nil {
node, ok := in.Select.Accept(v)
if !ok {
return in, false
}
in.Select = node.(*SelectStmt)
}
return v.Leave(in)
return v.Leave(n)
}
// DeleteStmt is a statement to delete rows from table.
@ -458,10 +637,12 @@ func (in *InsertStmt) Accept(v Visitor) (Node, bool) {
type DeleteStmt struct {
dmlNode
TableRefs *Join
// Used in both single table and multiple table delete statement.
TableRefs *TableRefsClause
// Only used in multiple table delete statement.
Tables []*TableName
Where ExprNode
Order []*OrderByItem
Order *OrderByClause
Limit *Limit
LowPriority bool
Ignore bool
@ -471,47 +652,49 @@ type DeleteStmt struct {
}
// Accept implements Node Accept interface.
func (de *DeleteStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(de); skipChildren {
return de, ok
func (n *DeleteStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*DeleteStmt)
node, ok := de.TableRefs.Accept(v)
node, ok := n.TableRefs.Accept(v)
if !ok {
return de, false
return n, false
}
de.TableRefs = node.(*Join)
n.TableRefs = node.(*TableRefsClause)
for i, val := range de.Tables {
for i, val := range n.Tables {
node, ok = val.Accept(v)
if !ok {
return de, false
return n, false
}
de.Tables[i] = node.(*TableName)
n.Tables[i] = node.(*TableName)
}
if de.Where != nil {
node, ok = de.Where.Accept(v)
if n.Where != nil {
node, ok = n.Where.Accept(v)
if !ok {
return de, false
return n, false
}
de.Where = node.(ExprNode)
n.Where = node.(ExprNode)
}
for i, val := range de.Order {
node, ok = val.Accept(v)
if n.Order != nil {
node, ok = n.Order.Accept(v)
if !ok {
return de, false
return n, false
}
de.Order[i] = node.(*OrderByItem)
n.Order = node.(*OrderByClause)
}
node, ok = de.Limit.Accept(v)
if !ok {
return de, false
if n.Limit != nil {
node, ok = n.Limit.Accept(v)
if !ok {
return n, false
}
n.Limit = node.(*Limit)
}
de.Limit = node.(*Limit)
return v.Leave(de)
return v.Leave(n)
}
// UpdateStmt is a statement to update columns of existing rows in tables with new values.
@ -519,10 +702,10 @@ func (de *DeleteStmt) Accept(v Visitor) (Node, bool) {
type UpdateStmt struct {
dmlNode
TableRefs *Join
TableRefs *TableRefsClause
List []*Assignment
Where ExprNode
Order []*OrderByItem
Order *OrderByClause
Limit *Limit
LowPriority bool
Ignore bool
@ -530,43 +713,46 @@ type UpdateStmt struct {
}
// Accept implements Node Accept interface.
func (up *UpdateStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(up); skipChildren {
return up, ok
func (n *UpdateStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := up.TableRefs.Accept(v)
n = newNod.(*UpdateStmt)
node, ok := n.TableRefs.Accept(v)
if !ok {
return up, false
return n, false
}
up.TableRefs = node.(*Join)
for i, val := range up.List {
n.TableRefs = node.(*TableRefsClause)
for i, val := range n.List {
node, ok = val.Accept(v)
if !ok {
return up, false
return n, false
}
up.List[i] = node.(*Assignment)
n.List[i] = node.(*Assignment)
}
if up.Where != nil {
node, ok = up.Where.Accept(v)
if n.Where != nil {
node, ok = n.Where.Accept(v)
if !ok {
return up, false
return n, false
}
up.Where = node.(ExprNode)
n.Where = node.(ExprNode)
}
for i, val := range up.Order {
node, ok = val.Accept(v)
if n.Order != nil {
node, ok = n.Order.Accept(v)
if !ok {
return up, false
return n, false
}
up.Order[i] = node.(*OrderByItem)
n.Order = node.(*OrderByClause)
}
node, ok = up.Limit.Accept(v)
if !ok {
return up, false
if n.Limit != nil {
node, ok = n.Limit.Accept(v)
if !ok {
return n, false
}
n.Limit = node.(*Limit)
}
up.Limit = node.(*Limit)
return v.Leave(up)
return v.Leave(n)
}
// Limit is the limit clause.
@ -578,9 +764,11 @@ type Limit struct {
}
// Accept implements Node Accept interface.
func (l *Limit) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(l); skipChildren {
return l, ok
func (n *Limit) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(l)
n = newNod.(*Limit)
return v.Leave(n)
}

View File

@ -14,8 +14,11 @@
package ast
import (
"fmt"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
var (
@ -26,6 +29,7 @@ var (
_ ExprNode = &CaseExpr{}
_ ExprNode = &SubqueryExpr{}
_ ExprNode = &CompareSubqueryExpr{}
_ Node = &ColumnName{}
_ ExprNode = &ColumnNameExpr{}
_ ExprNode = &DefaultExpr{}
_ ExprNode = &IdentifierExpr{}
@ -47,21 +51,58 @@ var (
// ValueExpr is the simple value expression.
type ValueExpr struct {
exprNode
// Val is the literal value.
Val interface{}
}
// NewValueExpr creates a ValueExpr with value, and sets default field type.
func NewValueExpr(value interface{}) *ValueExpr {
ve := &ValueExpr{}
ve.Data = types.RawData(value)
// TODO: make it more precise.
switch value.(type) {
case nil:
ve.Type = types.NewFieldType(mysql.TypeNull)
case bool, int64:
ve.Type = types.NewFieldType(mysql.TypeLonglong)
case uint64:
ve.Type = types.NewFieldType(mysql.TypeLonglong)
ve.Type.Flag |= mysql.UnsignedFlag
case string, UnquoteString:
ve.Type = types.NewFieldType(mysql.TypeVarchar)
ve.Type.Charset = mysql.DefaultCharset
ve.Type.Collate = mysql.DefaultCollationName
case float64:
ve.Type = types.NewFieldType(mysql.TypeDouble)
case []byte:
ve.Type = types.NewFieldType(mysql.TypeBlob)
ve.Type.Charset = "binary"
ve.Type.Collate = "binary"
case mysql.Bit:
ve.Type = types.NewFieldType(mysql.TypeBit)
case mysql.Hex:
ve.Type = types.NewFieldType(mysql.TypeVarchar)
ve.Type.Charset = "binary"
ve.Type.Collate = "binary"
case *types.DataItem:
ve.Type = value.(*types.DataItem).Type
default:
panic(fmt.Sprintf("illegal literal value type:%T", value))
}
return ve
}
// IsStatic implements ExprNode interface.
func (val *ValueExpr) IsStatic() bool {
func (n *ValueExpr) IsStatic() bool {
return true
}
// Accept implements Node interface.
func (val *ValueExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(val); skipChildren {
return val, ok
func (n *ValueExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(val)
n = newNod.(*ValueExpr)
return v.Leave(n)
}
// BetweenExpr is for "between and" or "not between and" expression.
@ -78,35 +119,37 @@ type BetweenExpr struct {
}
// Accept implements Node interface.
func (b *BetweenExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(b); skipChildren {
return b, ok
func (n *BetweenExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*BetweenExpr)
node, ok := b.Expr.Accept(v)
node, ok := n.Expr.Accept(v)
if !ok {
return b, false
return n, false
}
b.Expr = node.(ExprNode)
n.Expr = node.(ExprNode)
node, ok = b.Left.Accept(v)
node, ok = n.Left.Accept(v)
if !ok {
return b, false
return n, false
}
b.Left = node.(ExprNode)
n.Left = node.(ExprNode)
node, ok = b.Right.Accept(v)
node, ok = n.Right.Accept(v)
if !ok {
return b, false
return n, false
}
b.Right = node.(ExprNode)
n.Right = node.(ExprNode)
return v.Leave(b)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (b *BetweenExpr) IsStatic() bool {
return b.Expr.IsStatic() && b.Left.IsStatic() && b.Right.IsStatic()
func (n *BetweenExpr) IsStatic() bool {
return n.Expr.IsStatic() && n.Left.IsStatic() && n.Right.IsStatic()
}
// BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc.
@ -121,29 +164,31 @@ type BinaryOperationExpr struct {
}
// Accept implements Node interface.
func (o *BinaryOperationExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(o); skipChildren {
return o, ok
func (n *BinaryOperationExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*BinaryOperationExpr)
node, ok := o.L.Accept(v)
node, ok := n.L.Accept(v)
if !ok {
return o, false
return n, false
}
o.L = node.(ExprNode)
n.L = node.(ExprNode)
node, ok = o.R.Accept(v)
node, ok = n.R.Accept(v)
if !ok {
return o, false
return n, false
}
o.R = node.(ExprNode)
n.R = node.(ExprNode)
return v.Leave(o)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (o *BinaryOperationExpr) IsStatic() bool {
return o.L.IsStatic() && o.R.IsStatic()
func (n *BinaryOperationExpr) IsStatic() bool {
return n.L.IsStatic() && n.R.IsStatic()
}
// WhenClause is the when clause in Case expression for "when condition then result".
@ -156,27 +201,29 @@ type WhenClause struct {
}
// Accept implements Node Accept interface.
func (w *WhenClause) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(w); skipChildren {
return w, ok
func (n *WhenClause) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := w.Expr.Accept(v)
n = newNod.(*WhenClause)
node, ok := n.Expr.Accept(v)
if !ok {
return w, false
return n, false
}
w.Expr = node.(ExprNode)
n.Expr = node.(ExprNode)
node, ok = w.Result.Accept(v)
node, ok = n.Result.Accept(v)
if !ok {
return w, false
return n, false
}
w.Result = node.(ExprNode)
return v.Leave(w)
n.Result = node.(ExprNode)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (w *WhenClause) IsStatic() bool {
return w.Expr.IsStatic() && w.Result.IsStatic()
func (n *WhenClause) IsStatic() bool {
return n.Expr.IsStatic() && n.Result.IsStatic()
}
// CaseExpr is the case expression.
@ -191,45 +238,47 @@ type CaseExpr struct {
}
// Accept implements Node Accept interface.
func (f *CaseExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(f); skipChildren {
return f, ok
func (n *CaseExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if f.Value != nil {
node, ok := f.Value.Accept(v)
n = newNod.(*CaseExpr)
if n.Value != nil {
node, ok := n.Value.Accept(v)
if !ok {
return f, false
return n, false
}
f.Value = node.(ExprNode)
n.Value = node.(ExprNode)
}
for i, val := range f.WhenClauses {
for i, val := range n.WhenClauses {
node, ok := val.Accept(v)
if !ok {
return f, false
return n, false
}
f.WhenClauses[i] = node.(*WhenClause)
n.WhenClauses[i] = node.(*WhenClause)
}
if f.ElseClause != nil {
node, ok := f.ElseClause.Accept(v)
if n.ElseClause != nil {
node, ok := n.ElseClause.Accept(v)
if !ok {
return f, false
return n, false
}
f.ElseClause = node.(ExprNode)
n.ElseClause = node.(ExprNode)
}
return v.Leave(f)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (f *CaseExpr) IsStatic() bool {
if f.Value != nil && !f.Value.IsStatic() {
func (n *CaseExpr) IsStatic() bool {
if n.Value != nil && !n.Value.IsStatic() {
return false
}
for _, w := range f.WhenClauses {
for _, w := range n.WhenClauses {
if !w.IsStatic() {
return false
}
}
if f.ElseClause != nil && !f.ElseClause.IsStatic() {
if n.ElseClause != nil && !n.ElseClause.IsStatic() {
return false
}
return true
@ -239,30 +288,32 @@ func (f *CaseExpr) IsStatic() bool {
type SubqueryExpr struct {
exprNode
// Query is the query SelectNode.
Query *SelectStmt
Query ResultSetNode
}
// Accept implements Node Accept interface.
func (sq *SubqueryExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(sq); skipChildren {
return sq, ok
func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := sq.Query.Accept(v)
n = newNod.(*SubqueryExpr)
node, ok := n.Query.Accept(v)
if !ok {
return sq, false
return n, false
}
sq.Query = node.(*SelectStmt)
return v.Leave(sq)
n.Query = node.(ResultSetNode)
return v.Leave(n)
}
// SetResultFields implements ResultSet interface.
func (sq *SubqueryExpr) SetResultFields(rfs []*ResultField) {
sq.Query.SetResultFields(rfs)
func (n *SubqueryExpr) SetResultFields(rfs []*ResultField) {
n.Query.SetResultFields(rfs)
}
// GetResultFields implements ResultSet interface.
func (sq *SubqueryExpr) GetResultFields() []*ResultField {
return sq.Query.GetResultFields()
func (n *SubqueryExpr) GetResultFields() []*ResultField {
return n.Query.GetResultFields()
}
// CompareSubqueryExpr is the expression for "expr cmp (select ...)".
@ -282,21 +333,23 @@ type CompareSubqueryExpr struct {
}
// Accept implements Node Accept interface.
func (cs *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(cs); skipChildren {
return cs, ok
func (n *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := cs.L.Accept(v)
n = newNod.(*CompareSubqueryExpr)
node, ok := n.L.Accept(v)
if !ok {
return cs, false
return n, false
}
cs.L = node.(ExprNode)
node, ok = cs.R.Accept(v)
n.L = node.(ExprNode)
node, ok = n.R.Accept(v)
if !ok {
return cs, false
return n, false
}
cs.R = node.(*SubqueryExpr)
return v.Leave(cs)
n.R = node.(*SubqueryExpr)
return v.Leave(n)
}
// ColumnName represents column name.
@ -312,11 +365,13 @@ type ColumnName struct {
}
// Accept implements Node Accept interface.
func (cn *ColumnName) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(cn); skipChildren {
return cn, ok
func (n *ColumnName) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(cn)
n = newNod.(*ColumnName)
return v.Leave(n)
}
// ColumnNameExpr represents a column name expression.
@ -328,16 +383,18 @@ type ColumnNameExpr struct {
}
// Accept implements Node Accept interface.
func (cr *ColumnNameExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(cr); skipChildren {
return cr, ok
func (n *ColumnNameExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := cr.Name.Accept(v)
n = newNod.(*ColumnNameExpr)
node, ok := n.Name.Accept(v)
if !ok {
return cr, false
return n, false
}
cr.Name = node.(*ColumnName)
return v.Leave(cr)
n.Name = node.(*ColumnName)
return v.Leave(n)
}
// DefaultExpr is the default expression using default value for a column.
@ -348,18 +405,20 @@ type DefaultExpr struct {
}
// Accept implements Node Accept interface.
func (d *DefaultExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(d); skipChildren {
return d, ok
func (n *DefaultExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if d.Name != nil {
node, ok := d.Name.Accept(v)
n = newNod.(*DefaultExpr)
if n.Name != nil {
node, ok := n.Name.Accept(v)
if !ok {
return d, false
return n, false
}
d.Name = node.(*ColumnName)
n.Name = node.(*ColumnName)
}
return v.Leave(d)
return v.Leave(n)
}
// IdentifierExpr represents an identifier expression.
@ -370,11 +429,13 @@ type IdentifierExpr struct {
}
// Accept implements Node Accept interface.
func (i *IdentifierExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(i); skipChildren {
return i, ok
func (n *IdentifierExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(i)
n = newNod.(*IdentifierExpr)
return v.Leave(n)
}
// ExistsSubqueryExpr is the expression for "exists (select ...)".
@ -386,16 +447,18 @@ type ExistsSubqueryExpr struct {
}
// Accept implements Node Accept interface.
func (es *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(es); skipChildren {
return es, ok
func (n *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := es.Sel.Accept(v)
n = newNod.(*ExistsSubqueryExpr)
node, ok := n.Sel.Accept(v)
if !ok {
return es, false
return n, false
}
es.Sel = node.(*SubqueryExpr)
return v.Leave(es)
n.Sel = node.(*SubqueryExpr)
return v.Leave(n)
}
// PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)".
@ -412,30 +475,32 @@ type PatternInExpr struct {
}
// Accept implements Node Accept interface.
func (pi *PatternInExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(pi); skipChildren {
return pi, ok
func (n *PatternInExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := pi.Expr.Accept(v)
n = newNod.(*PatternInExpr)
node, ok := n.Expr.Accept(v)
if !ok {
return pi, false
return n, false
}
pi.Expr = node.(ExprNode)
for i, val := range pi.List {
n.Expr = node.(ExprNode)
for i, val := range n.List {
node, ok = val.Accept(v)
if !ok {
return pi, false
return n, false
}
pi.List[i] = node.(ExprNode)
n.List[i] = node.(ExprNode)
}
if pi.Sel != nil {
node, ok = pi.Sel.Accept(v)
if n.Sel != nil {
node, ok = n.Sel.Accept(v)
if !ok {
return pi, false
return n, false
}
pi.Sel = node.(*SubqueryExpr)
n.Sel = node.(*SubqueryExpr)
}
return v.Leave(pi)
return v.Leave(n)
}
// IsNullExpr is the expression for null check.
@ -448,21 +513,23 @@ type IsNullExpr struct {
}
// Accept implements Node Accept interface.
func (is *IsNullExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(is); skipChildren {
return is, ok
func (n *IsNullExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := is.Expr.Accept(v)
n = newNod.(*IsNullExpr)
node, ok := n.Expr.Accept(v)
if !ok {
return is, false
return n, false
}
is.Expr = node.(ExprNode)
return v.Leave(is)
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (is *IsNullExpr) IsStatic() bool {
return is.Expr.IsStatic()
func (n *IsNullExpr) IsStatic() bool {
return n.Expr.IsStatic()
}
// IsTruthExpr is the expression for true/false check.
@ -477,21 +544,23 @@ type IsTruthExpr struct {
}
// Accept implements Node Accept interface.
func (is *IsTruthExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(is); skipChildren {
return is, ok
func (n *IsTruthExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := is.Expr.Accept(v)
n = newNod.(*IsTruthExpr)
node, ok := n.Expr.Accept(v)
if !ok {
return is, false
return n, false
}
is.Expr = node.(ExprNode)
return v.Leave(is)
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (is *IsTruthExpr) IsStatic() bool {
return is.Expr.IsStatic()
func (n *IsTruthExpr) IsStatic() bool {
return n.Expr.IsStatic()
}
// PatternLikeExpr is the expression for like operator, e.g, expr like "%123%"
@ -508,40 +577,49 @@ type PatternLikeExpr struct {
}
// Accept implements Node Accept interface.
func (pl *PatternLikeExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(pl); skipChildren {
return pl, ok
func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := pl.Expr.Accept(v)
if !ok {
return pl, false
n = newNod.(*PatternLikeExpr)
if n.Expr != nil {
node, ok := n.Expr.Accept(v)
if !ok {
return n, false
}
n.Expr = node.(ExprNode)
}
pl.Expr = node.(ExprNode)
node, ok = pl.Pattern.Accept(v)
if !ok {
return pl, false
if n.Pattern != nil {
node, ok := n.Pattern.Accept(v)
if !ok {
return n, false
}
n.Pattern = node.(ExprNode)
}
pl.Pattern = node.(ExprNode)
return v.Leave(pl)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (pl *PatternLikeExpr) IsStatic() bool {
return pl.Expr.IsStatic() && pl.Pattern.IsStatic()
func (n *PatternLikeExpr) IsStatic() bool {
return n.Expr.IsStatic() && n.Pattern.IsStatic()
}
// ParamMarkerExpr expresion holds a place for another expression.
// Used in parsing prepare statement.
type ParamMarkerExpr struct {
exprNode
Offset int
}
// Accept implements Node Accept interface.
func (pm *ParamMarkerExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(pm); skipChildren {
return pm, ok
func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(pm)
n = newNod.(*ParamMarkerExpr)
return v.Leave(n)
}
// ParenthesesExpr is the parentheses expression.
@ -552,23 +630,25 @@ type ParenthesesExpr struct {
}
// Accept implements Node Accept interface.
func (p *ParenthesesExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(p); skipChildren {
return p, ok
func (n *ParenthesesExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if p.Expr != nil {
node, ok := p.Expr.Accept(v)
n = newNod.(*ParenthesesExpr)
if n.Expr != nil {
node, ok := n.Expr.Accept(v)
if !ok {
return p, false
return n, false
}
p.Expr = node.(ExprNode)
n.Expr = node.(ExprNode)
}
return v.Leave(p)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (p *ParenthesesExpr) IsStatic() bool {
return p.Expr.IsStatic()
func (n *ParenthesesExpr) IsStatic() bool {
return n.Expr.IsStatic()
}
// PositionExpr is the expression for order by and group by position.
@ -583,16 +663,18 @@ type PositionExpr struct {
}
// IsStatic implements the ExprNode IsStatic interface.
func (p *PositionExpr) IsStatic() bool {
func (n *PositionExpr) IsStatic() bool {
return true
}
// Accept implements Node Accept interface.
func (p *PositionExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(p); skipChildren {
return p, ok
func (n *PositionExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(p)
n = newNod.(*PositionExpr)
return v.Leave(n)
}
// PatternRegexpExpr is the pattern expression for pattern match.
@ -607,26 +689,28 @@ type PatternRegexpExpr struct {
}
// Accept implements Node Accept interface.
func (p *PatternRegexpExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(p); skipChildren {
return p, ok
func (n *PatternRegexpExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := p.Expr.Accept(v)
n = newNod.(*PatternRegexpExpr)
node, ok := n.Expr.Accept(v)
if !ok {
return p, false
return n, false
}
p.Expr = node.(ExprNode)
node, ok = p.Pattern.Accept(v)
n.Expr = node.(ExprNode)
node, ok = n.Pattern.Accept(v)
if !ok {
return p, false
return n, false
}
p.Pattern = node.(ExprNode)
return v.Leave(p)
n.Pattern = node.(ExprNode)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (p *PatternRegexpExpr) IsStatic() bool {
return p.Expr.IsStatic() && p.Pattern.IsStatic()
func (n *PatternRegexpExpr) IsStatic() bool {
return n.Expr.IsStatic() && n.Pattern.IsStatic()
}
// RowExpr is the expression for row constructor.
@ -638,23 +722,25 @@ type RowExpr struct {
}
// Accept implements Node Accept interface.
func (r *RowExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(r); skipChildren {
return r, ok
func (n *RowExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range r.Values {
n = newNod.(*RowExpr)
for i, val := range n.Values {
node, ok := val.Accept(v)
if !ok {
return r, false
return n, false
}
r.Values[i] = node.(ExprNode)
n.Values[i] = node.(ExprNode)
}
return v.Leave(r)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (r *RowExpr) IsStatic() bool {
for _, v := range r.Values {
func (n *RowExpr) IsStatic() bool {
for _, v := range n.Values {
if !v.IsStatic() {
return false
}
@ -672,21 +758,23 @@ type UnaryOperationExpr struct {
}
// Accept implements Node Accept interface.
func (u *UnaryOperationExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(u); skipChildren {
return u, ok
func (n *UnaryOperationExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := u.V.Accept(v)
n = newNod.(*UnaryOperationExpr)
node, ok := n.V.Accept(v)
if !ok {
return u, false
return n, false
}
u.V = node.(ExprNode)
return v.Leave(u)
n.V = node.(ExprNode)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (u *UnaryOperationExpr) IsStatic() bool {
return u.V.IsStatic()
func (n *UnaryOperationExpr) IsStatic() bool {
return n.V.IsStatic()
}
// ValuesExpr is the expression used in INSERT VALUES
@ -697,16 +785,18 @@ type ValuesExpr struct {
}
// Accept implements Node Accept interface.
func (va *ValuesExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(va); skipChildren {
return va, ok
func (n *ValuesExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := va.Column.Accept(v)
n = newNod.(*ValuesExpr)
node, ok := n.Column.Accept(v)
if !ok {
return va, false
return n, false
}
va.Column = node.(*ColumnName)
return v.Leave(va)
n.Column = node.(*ColumnName)
return v.Leave(n)
}
// VariableExpr is the expression for variable.
@ -721,9 +811,11 @@ type VariableExpr struct {
}
// Accept implements Node Accept interface.
func (va *VariableExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(va); skipChildren {
return va, ok
func (n *VariableExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(va)
n = newNod.(*VariableExpr)
return v.Leave(n)
}

View File

@ -26,44 +26,49 @@ var (
_ FuncNode = &FuncConvertExpr{}
_ FuncNode = &FuncCastExpr{}
_ FuncNode = &FuncSubstringExpr{}
_ FuncNode = &FuncLocateExpr{}
_ FuncNode = &FuncTrimExpr{}
_ FuncNode = &FuncDateArithExpr{}
_ FuncNode = &AggregateFuncExpr{}
)
// UnquoteString is not quoted when printed.
type UnquoteString string
// FuncCallExpr is for function expression.
type FuncCallExpr struct {
funcNode
// F is the function name.
F string
FnName string
// Args is the function args.
Args []ExprNode
// Distinct only affetcts sum, avg, count, group_concat,
// so we can ignore it in other functions
Distinct bool
}
// Accept implements Node interface.
func (c *FuncCallExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(c); skipChildren {
return c, ok
func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range c.Args {
n = newNod.(*FuncCallExpr)
for i, val := range n.Args {
node, ok := val.Accept(v)
if !ok {
return c, false
return n, false
}
c.Args[i] = node.(ExprNode)
n.Args[i] = node.(ExprNode)
}
return v.Leave(c)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (c *FuncCallExpr) IsStatic() bool {
v := builtin.Funcs[strings.ToLower(c.F)]
func (n *FuncCallExpr) IsStatic() bool {
v := builtin.Funcs[strings.ToLower(n.FnName)]
if v.F == nil || !v.IsStatic {
return false
}
for _, v := range c.Args {
for _, v := range n.Args {
if !v.IsStatic() {
return false
}
@ -81,21 +86,23 @@ type FuncExtractExpr struct {
}
// Accept implements Node Accept interface.
func (ex *FuncExtractExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ex); skipChildren {
return ex, ok
func (n *FuncExtractExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ex.Date.Accept(v)
n = newNod.(*FuncExtractExpr)
node, ok := n.Date.Accept(v)
if !ok {
return ex, false
return n, false
}
ex.Date = node.(ExprNode)
return v.Leave(ex)
n.Date = node.(ExprNode)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (ex *FuncExtractExpr) IsStatic() bool {
return ex.Date.IsStatic()
func (n *FuncExtractExpr) IsStatic() bool {
return n.Date.IsStatic()
}
// FuncConvertExpr provides a way to convert data between different character sets.
@ -109,21 +116,23 @@ type FuncConvertExpr struct {
}
// IsStatic implements the ExprNode IsStatic interface.
func (f *FuncConvertExpr) IsStatic() bool {
return f.Expr.IsStatic()
func (n *FuncConvertExpr) IsStatic() bool {
return n.Expr.IsStatic()
}
// Accept implements Node Accept interface.
func (f *FuncConvertExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(f); skipChildren {
return f, ok
func (n *FuncConvertExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := f.Expr.Accept(v)
n = newNod.(*FuncConvertExpr)
node, ok := n.Expr.Accept(v)
if !ok {
return f, false
return n, false
}
f.Expr = node.(ExprNode)
return v.Leave(f)
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// CastFunctionType is the type for cast function.
@ -149,21 +158,23 @@ type FuncCastExpr struct {
}
// IsStatic implements the ExprNode IsStatic interface.
func (f *FuncCastExpr) IsStatic() bool {
return f.Expr.IsStatic()
func (n *FuncCastExpr) IsStatic() bool {
return n.Expr.IsStatic()
}
// Accept implements Node Accept interface.
func (f *FuncCastExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(f); skipChildren {
return f, ok
func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := f.Expr.Accept(v)
n = newNod.(*FuncCastExpr)
node, ok := n.Expr.Accept(v)
if !ok {
return f, false
return n, false
}
f.Expr = node.(ExprNode)
return v.Leave(f)
n.Expr = node.(ExprNode)
return v.Leave(n)
}
// FuncSubstringExpr returns the substring as specified.
@ -177,31 +188,35 @@ type FuncSubstringExpr struct {
}
// Accept implements Node Accept interface.
func (sf *FuncSubstringExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(sf); skipChildren {
return sf, ok
func (n *FuncSubstringExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := sf.StrExpr.Accept(v)
n = newNod.(*FuncSubstringExpr)
node, ok := n.StrExpr.Accept(v)
if !ok {
return sf, false
return n, false
}
sf.StrExpr = node.(ExprNode)
node, ok = sf.Pos.Accept(v)
n.StrExpr = node.(ExprNode)
node, ok = n.Pos.Accept(v)
if !ok {
return sf, false
return n, false
}
sf.Pos = node.(ExprNode)
node, ok = sf.Len.Accept(v)
if !ok {
return sf, false
n.Pos = node.(ExprNode)
if n.Len != nil {
node, ok = n.Len.Accept(v)
if !ok {
return n, false
}
n.Len = node.(ExprNode)
}
sf.Len = node.(ExprNode)
return v.Leave(sf)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (sf *FuncSubstringExpr) IsStatic() bool {
return sf.StrExpr.IsStatic() && sf.Pos.IsStatic() && sf.Len.IsStatic()
func (n *FuncSubstringExpr) IsStatic() bool {
return n.StrExpr.IsStatic() && n.Pos.IsStatic() && n.Len.IsStatic()
}
// FuncSubstringIndexExpr returns the substring as specified.
@ -215,26 +230,28 @@ type FuncSubstringIndexExpr struct {
}
// Accept implements Node Accept interface.
func (si *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(si); skipChildren {
return si, ok
func (n *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := si.StrExpr.Accept(v)
n = newNod.(*FuncSubstringIndexExpr)
node, ok := n.StrExpr.Accept(v)
if !ok {
return si, false
return n, false
}
si.StrExpr = node.(ExprNode)
node, ok = si.Delim.Accept(v)
n.StrExpr = node.(ExprNode)
node, ok = n.Delim.Accept(v)
if !ok {
return si, false
return n, false
}
si.Delim = node.(ExprNode)
node, ok = si.Count.Accept(v)
n.Delim = node.(ExprNode)
node, ok = n.Count.Accept(v)
if !ok {
return si, false
return n, false
}
si.Count = node.(ExprNode)
return v.Leave(si)
n.Count = node.(ExprNode)
return v.Leave(n)
}
// FuncLocateExpr returns the position of the first occurrence of substring.
@ -248,26 +265,28 @@ type FuncLocateExpr struct {
}
// Accept implements Node Accept interface.
func (le *FuncLocateExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(le); skipChildren {
return le, ok
func (n *FuncLocateExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := le.Str.Accept(v)
n = newNod.(*FuncLocateExpr)
node, ok := n.Str.Accept(v)
if !ok {
return le, false
return n, false
}
le.Str = node.(ExprNode)
node, ok = le.SubStr.Accept(v)
n.Str = node.(ExprNode)
node, ok = n.SubStr.Accept(v)
if !ok {
return le, false
return n, false
}
le.SubStr = node.(ExprNode)
node, ok = le.Pos.Accept(v)
n.SubStr = node.(ExprNode)
node, ok = n.Pos.Accept(v)
if !ok {
return le, false
return n, false
}
le.Pos = node.(ExprNode)
return v.Leave(le)
n.Pos = node.(ExprNode)
return v.Leave(n)
}
// TrimDirectionType is the type for trim direction.
@ -295,27 +314,102 @@ type FuncTrimExpr struct {
}
// Accept implements Node Accept interface.
func (tf *FuncTrimExpr) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(tf); skipChildren {
return tf, ok
func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := tf.Str.Accept(v)
n = newNod.(*FuncTrimExpr)
node, ok := n.Str.Accept(v)
if !ok {
return tf, false
return n, false
}
tf.Str = node.(ExprNode)
node, ok = tf.RemStr.Accept(v)
n.Str = node.(ExprNode)
node, ok = n.RemStr.Accept(v)
if !ok {
return tf, false
return n, false
}
tf.RemStr = node.(ExprNode)
return v.Leave(tf)
n.RemStr = node.(ExprNode)
return v.Leave(n)
}
// IsStatic implements the ExprNode IsStatic interface.
func (tf *FuncTrimExpr) IsStatic() bool {
return tf.Str.IsStatic() && tf.RemStr.IsStatic()
func (n *FuncTrimExpr) IsStatic() bool {
return n.Str.IsStatic() && n.RemStr.IsStatic()
}
// TypeStar is a special type for "*".
type TypeStar string
// DateArithType is type for DateArith option.
type DateArithType byte
const (
// DateAdd is to run date_add function option.
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
DateAdd DateArithType = iota + 1
// DateSub is to run date_sub function option.
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
DateSub
)
// FuncDateArithExpr is the struct for date arithmetic functions.
type FuncDateArithExpr struct {
funcNode
Op DateArithType
Unit string
Date ExprNode
Interval ExprNode
}
// Accept implements Node Accept interface.
func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*FuncDateArithExpr)
if n.Date != nil {
node, ok := n.Date.Accept(v)
if !ok {
return n, false
}
n.Date = node.(ExprNode)
}
if n.Interval != nil {
node, ok := n.Date.Accept(v)
if !ok {
return n, false
}
n.Date = node.(ExprNode)
}
return v.Leave(n)
}
// AggregateFuncExpr represents aggregate function expression.
type AggregateFuncExpr struct {
funcNode
// F is the function name.
F string
// Args is the function args.
Args []ExprNode
// If distinct is true, the function only aggregate distinct values.
// For example, column c1 values are "1", "2", "2", "sum(c1)" is "5",
// but "sum(distinct c1)" is "3".
Distinct bool
}
// Accept implements Node Accept interface.
func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*AggregateFuncExpr)
for i, val := range n.Args {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Args[i] = node.(ExprNode)
}
return v.Leave(n)
}

View File

@ -13,6 +13,8 @@
package ast
import "github.com/pingcap/tidb/mysql"
var (
_ StmtNode = &ExplainStmt{}
_ StmtNode = &PrepareStmt{}
@ -26,7 +28,9 @@ var (
_ StmtNode = &SetStmt{}
_ StmtNode = &SetCharsetStmt{}
_ StmtNode = &SetPwdStmt{}
_ StmtNode = &CreateUserStmt{}
_ StmtNode = &DoStmt{}
_ StmtNode = &GrantStmt{}
_ Node = &VariableAssignment{}
)
@ -57,16 +61,18 @@ type ExplainStmt struct {
}
// Accept implements Node Accept interface.
func (es *ExplainStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(es); skipChildren {
return es, ok
func (n *ExplainStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := es.Stmt.Accept(v)
n = newNod.(*ExplainStmt)
node, ok := n.Stmt.Accept(v)
if !ok {
return es, false
return n, false
}
es.Stmt = node.(DMLNode)
return v.Leave(es)
n.Stmt = node.(DMLNode)
return v.Leave(n)
}
// PrepareStmt is a statement to prepares a SQL statement which contains placeholders,
@ -83,16 +89,18 @@ type PrepareStmt struct {
}
// Accept implements Node Accept interface.
func (ps *PrepareStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ps); skipChildren {
return ps, ok
func (n *PrepareStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := ps.SQLVar.Accept(v)
n = newNod.(*PrepareStmt)
node, ok := n.SQLVar.Accept(v)
if !ok {
return ps, false
return n, false
}
ps.SQLVar = node.(*VariableExpr)
return v.Leave(ps)
n.SQLVar = node.(*VariableExpr)
return v.Leave(n)
}
// DeallocateStmt is a statement to release PreparedStmt.
@ -105,11 +113,13 @@ type DeallocateStmt struct {
}
// Accept implements Node Accept interface.
func (ds *DeallocateStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ds); skipChildren {
return ds, ok
func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(ds)
n = newNod.(*DeallocateStmt)
return v.Leave(n)
}
// ExecuteStmt is a statement to execute PreparedStmt.
@ -123,18 +133,20 @@ type ExecuteStmt struct {
}
// Accept implements Node Accept interface.
func (es *ExecuteStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(es); skipChildren {
return es, ok
func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range es.UsingVars {
n = newNod.(*ExecuteStmt)
for i, val := range n.UsingVars {
node, ok := val.Accept(v)
if !ok {
return es, false
return n, false
}
es.UsingVars[i] = node.(ExprNode)
n.UsingVars[i] = node.(ExprNode)
}
return v.Leave(es)
return v.Leave(n)
}
// ShowStmtType is the type for SHOW statement.
@ -152,12 +164,13 @@ const (
ShowVariables
ShowCollation
ShowCreateTable
ShowGrants
)
// ShowStmt is a statement to provide information about databases, tables, columns and so on.
// See: https://dev.mysql.com/doc/refman/5.7/en/show.html
type ShowStmt struct {
stmtNode
dmlNode
Tp ShowStmtType // Databases/Tables/Columns/....
DBName string
@ -165,6 +178,7 @@ type ShowStmt struct {
Column *ColumnName // Used for `desc table column`.
Flag int // Some flag parsed from sql, such as FULL.
Full bool
User string // Used for show grants.
// Used by show variables
GlobalScope bool
@ -173,39 +187,41 @@ type ShowStmt struct {
}
// Accept implements Node Accept interface.
func (ss *ShowStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(ss); skipChildren {
return ss, ok
func (n *ShowStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
if ss.Table != nil {
node, ok := ss.Table.Accept(v)
n = newNod.(*ShowStmt)
if n.Table != nil {
node, ok := n.Table.Accept(v)
if !ok {
return ss, false
return n, false
}
ss.Table = node.(*TableName)
n.Table = node.(*TableName)
}
if ss.Column != nil {
node, ok := ss.Column.Accept(v)
if n.Column != nil {
node, ok := n.Column.Accept(v)
if !ok {
return ss, false
return n, false
}
ss.Column = node.(*ColumnName)
n.Column = node.(*ColumnName)
}
if ss.Pattern != nil {
node, ok := ss.Pattern.Accept(v)
if n.Pattern != nil {
node, ok := n.Pattern.Accept(v)
if !ok {
return ss, false
return n, false
}
ss.Pattern = node.(*PatternLikeExpr)
n.Pattern = node.(*PatternLikeExpr)
}
if ss.Where != nil {
node, ok := ss.Where.Accept(v)
if n.Where != nil {
node, ok := n.Where.Accept(v)
if !ok {
return ss, false
return n, false
}
ss.Where = node.(ExprNode)
n.Where = node.(ExprNode)
}
return v.Leave(ss)
return v.Leave(n)
}
// BeginStmt is a statement to start a new transaction.
@ -215,11 +231,13 @@ type BeginStmt struct {
}
// Accept implements Node Accept interface.
func (bs *BeginStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(bs); skipChildren {
return bs, ok
func (n *BeginStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(bs)
n = newNod.(*BeginStmt)
return v.Leave(n)
}
// CommitStmt is a statement to commit the current transaction.
@ -229,11 +247,13 @@ type CommitStmt struct {
}
// Accept implements Node Accept interface.
func (cs *CommitStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(cs); skipChildren {
return cs, ok
func (n *CommitStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(cs)
n = newNod.(*CommitStmt)
return v.Leave(n)
}
// RollbackStmt is a statement to roll back the current transaction.
@ -243,11 +263,13 @@ type RollbackStmt struct {
}
// Accept implements Node Accept interface.
func (rs *RollbackStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(rs); skipChildren {
return rs, ok
func (n *RollbackStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(rs)
n = newNod.(*RollbackStmt)
return v.Leave(n)
}
// UseStmt is a statement to use the DBName database as the current database.
@ -259,11 +281,13 @@ type UseStmt struct {
}
// Accept implements Node Accept interface.
func (us *UseStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(us); skipChildren {
return us, ok
func (n *UseStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(us)
n = newNod.(*UseStmt)
return v.Leave(n)
}
// VariableAssignment is a variable assignment struct.
@ -276,16 +300,18 @@ type VariableAssignment struct {
}
// Accept implements Node interface.
func (va *VariableAssignment) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(va); skipChildren {
return va, ok
func (n *VariableAssignment) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
node, ok := va.Value.Accept(v)
n = newNod.(*VariableAssignment)
node, ok := n.Value.Accept(v)
if !ok {
return va, false
return n, false
}
va.Value = node.(ExprNode)
return v.Leave(va)
n.Value = node.(ExprNode)
return v.Leave(n)
}
// SetStmt is the statement to set variables.
@ -296,18 +322,20 @@ type SetStmt struct {
}
// Accept implements Node Accept interface.
func (set *SetStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(set); skipChildren {
return set, ok
func (n *SetStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range set.Variables {
n = newNod.(*SetStmt)
for i, val := range n.Variables {
node, ok := val.Accept(v)
if !ok {
return set, false
return n, false
}
set.Variables[i] = node.(*VariableAssignment)
n.Variables[i] = node.(*VariableAssignment)
}
return v.Leave(set)
return v.Leave(n)
}
// SetCharsetStmt is a statement to assign values to character and collation variables.
@ -320,11 +348,13 @@ type SetCharsetStmt struct {
}
// Accept implements Node Accept interface.
func (set *SetCharsetStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(set); skipChildren {
return set, ok
func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(set)
n = newNod.(*SetCharsetStmt)
return v.Leave(n)
}
// SetPwdStmt is a statement to assign a password to user account.
@ -337,11 +367,13 @@ type SetPwdStmt struct {
}
// Accept implements Node Accept interface.
func (set *SetPwdStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(set); skipChildren {
return set, ok
func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
return v.Leave(set)
n = newNod.(*SetPwdStmt)
return v.Leave(n)
}
// UserSpec is used for parsing create user statement.
@ -353,10 +385,20 @@ type UserSpec struct {
// CreateUserStmt creates user account.
// See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html
type CreateUserStmt struct {
stmtNode
IfNotExists bool
Specs []*UserSpec
}
Text string
// Accept implements Node Accept interface.
func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*CreateUserStmt)
return v.Leave(n)
}
// DoStmt is the struct for DO statement.
@ -367,16 +409,100 @@ type DoStmt struct {
}
// Accept implements Node Accept interface.
func (do *DoStmt) Accept(v Visitor) (Node, bool) {
if skipChildren, ok := v.Enter(do); skipChildren {
return do, ok
func (n *DoStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
for i, val := range do.Exprs {
n = newNod.(*DoStmt)
for i, val := range n.Exprs {
node, ok := val.Accept(v)
if !ok {
return do, false
return n, false
}
do.Exprs[i] = node.(ExprNode)
n.Exprs[i] = node.(ExprNode)
}
return v.Leave(do)
return v.Leave(n)
}
// PrivElem is the privilege type and optional column list.
type PrivElem struct {
node
Priv mysql.PrivilegeType
Cols []*ColumnName
}
// Accept implements Node Accept interface.
func (n *PrivElem) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*PrivElem)
for i, val := range n.Cols {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Cols[i] = node.(*ColumnName)
}
return v.Leave(n)
}
// ObjectTypeType is the type for object type.
type ObjectTypeType int
const (
// ObjectTypeNone is for empty object type.
ObjectTypeNone ObjectTypeType = iota
// ObjectTypeTable means the following object is a table.
ObjectTypeTable
)
// GrantLevelType is the type for grant level.
type GrantLevelType int
const (
// GrantLevelNone is the dummy const for default value.
GrantLevelNone GrantLevelType = iota
// GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server.
GrantLevelGlobal
// GrantLevelDB means the privileges apply to all objects in a given database.
GrantLevelDB
// GrantLevelTable means the privileges apply to all columns in a given table.
GrantLevelTable
)
// GrantLevel is used for store the privilege scope.
type GrantLevel struct {
Level GrantLevelType
DBName string
TableName string
}
// GrantStmt is the struct for GRANT statement.
type GrantStmt struct {
stmtNode
Privs []*PrivElem
ObjectType ObjectTypeType
Level *GrantLevel
Users []*UserSpec
}
// Accept implements Node Accept interface.
func (n *GrantStmt) Accept(v Visitor) (Node, bool) {
newNod, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNod)
}
n = newNod.(*GrantStmt)
for i, val := range n.Privs {
node, ok := val.Accept(v)
if !ok {
return n, false
}
n.Privs[i] = node.(*PrivElem)
}
return v.Leave(n)
}

File diff suppressed because it is too large Load Diff

View File

@ -1,79 +0,0 @@
package parser
import (
"fmt"
"testing"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
)
func TestT(t *testing.T) {
TestingT(t)
}
var _ = Suite(&testParserSuite{})
var _ = Suite(&testParserSuite{})
type testParserSuite struct {
}
func (s *testParserSuite) TestSimple(c *C) {
// Testcase for unreserved keywords
unreservedKws := []string{
"auto_increment", "after", "begin", "bit", "bool", "boolean", "charset", "columns", "commit",
"date", "datetime", "deallocate", "do", "end", "engine", "engines", "execute", "first", "full",
"local", "names", "offset", "password", "prepare", "quick", "rollback", "session", "signed",
"start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown",
"value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified",
"collation", "comment", "avg_row_length", "checksum", "compression", "connection", "key_block_size",
"max_rows", "min_rows", "national", "row", "quarter", "escape",
}
for _, kw := range unreservedKws {
src := fmt.Sprintf("SELECT %s FROM tbl;", kw)
l := NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
c.Assert(l.errs, HasLen, 0, Commentf("source %s", src))
}
// Testcase for prepared statement
src := "SELECT id+?, id+? from t;"
l := NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
c.Assert(len(l.Stmts()), Equals, 1)
// Testcase for -- Comment and unary -- operator
src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;"
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
c.Assert(len(l.Stmts()), Equals, 2)
// Testcase for CONVERT(expr,type)
src = "SELECT CONVERT('111', SIGNED);"
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
st := l.Stmts()[0]
ss, ok := st.(*ast.SelectStmt)
c.Assert(ok, IsTrue)
c.Assert(len(ss.Fields), Equals, 1)
cv, ok := ss.Fields[0].Expr.(*ast.FuncCastExpr)
c.Assert(ok, IsTrue)
c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction)
// For query start with comment
srcs := []string{
"/* some comments */ SELECT CONVERT('111', SIGNED) ;",
"/* some comments */ /*comment*/ SELECT CONVERT('111', SIGNED) ;",
"SELECT /*comment*/ CONVERT('111', SIGNED) ;",
"SELECT CONVERT('111', /*comment*/ SIGNED) ;",
"SELECT CONVERT('111', SIGNED) /*comment*/;",
}
for _, src := range srcs {
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
st = l.Stmts()[0]
ss, ok = st.(*ast.SelectStmt)
c.Assert(ok, IsTrue)
}
}

File diff suppressed because it is too large Load Diff

View File

@ -24,6 +24,7 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/optimizer"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/stmt"
@ -197,5 +198,7 @@ func statement(sql string) stmt.Statement {
log.Debug("Compile", sql)
lexer := parser.NewLexer(sql)
parser.YYParse(lexer)
return lexer.Stmts()[0].(stmt.Statement)
compiler := &optimizer.Compiler{}
stm, _ := compiler.Compile(lexer.Stmts()[0])
return stm
}

View File

@ -18,17 +18,6 @@ DEL /F /A /Q y.output
golex -o parser/scanner.go parser/scanner.l
@echo [Ast-Parser]
go get github.com/qiuyesuifeng/goyacc
go get github.com/qiuyesuifeng/golex
type nul >>temp.XXXXXX
goyacc -o nul -xegen "temp.XXXXXX" parser/parser.y
goyacc -o ast/parser/parser.go -xe "temp.XXXXXX" ast/parser/parser.y
DEL /F /A /Q temp.XXXXXX
DEL /F /A /Q y.output
golex -o ast/parser/scanner.go ast/parser/scanner.l
@echo [Build]
godep go build -ldflags '%LDFLAGS%'

26
optimizer/aggregator.go Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
// Aggregator is the interface to
// compute aggregate function result.
type Aggregator interface {
// Input adds an input value to aggregator.
// The input values are accumulated in the aggregator.
Input(in ...interface{}) error
// Output uses input values to compute the aggregated result.
Output() interface{}
// Clear clears the input values.
Clear()
}

132
optimizer/compiler.go Normal file
View File

@ -0,0 +1,132 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"sort"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/stmt"
)
// Compiler compiles ast.Node into an executable statement.
type Compiler struct {
converter *expressionConverter
}
// Compile compiles a ast.Node into an executable statement.
func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) {
validator := &validator{}
if _, ok := node.Accept(validator); !ok {
return nil, errors.Trace(validator.err)
}
// binder := &InfoBinder{}
// if _, ok := node.Accept(validator); !ok {
// return nil, errors.Trace(binder.Err)
// }
tpComputer := &typeComputer{}
if _, ok := node.Accept(tpComputer); !ok {
return nil, errors.Trace(tpComputer.err)
}
c := newExpressionConverter()
com.converter = c
switch v := node.(type) {
case *ast.InsertStmt:
return convertInsert(c, v)
case *ast.DeleteStmt:
return convertDelete(c, v)
case *ast.UpdateStmt:
return convertUpdate(c, v)
case *ast.SelectStmt:
return convertSelect(c, v)
case *ast.UnionStmt:
return convertUnion(c, v)
case *ast.CreateDatabaseStmt:
return convertCreateDatabase(c, v)
case *ast.DropDatabaseStmt:
return convertDropDatabase(c, v)
case *ast.CreateTableStmt:
return convertCreateTable(c, v)
case *ast.DropTableStmt:
return convertDropTable(c, v)
case *ast.CreateIndexStmt:
return convertCreateIndex(c, v)
case *ast.DropIndexStmt:
return convertDropIndex(c, v)
case *ast.AlterTableStmt:
return convertAlterTable(c, v)
case *ast.TruncateTableStmt:
return convertTruncateTable(c, v)
case *ast.ExplainStmt:
return convertExplain(c, v)
case *ast.PrepareStmt:
return convertPrepare(c, v)
case *ast.DeallocateStmt:
return convertDeallocate(c, v)
case *ast.ExecuteStmt:
return convertExecute(c, v)
case *ast.ShowStmt:
return convertShow(c, v)
case *ast.BeginStmt:
return convertBegin(c, v)
case *ast.CommitStmt:
return convertCommit(c, v)
case *ast.RollbackStmt:
return convertRollback(c, v)
case *ast.UseStmt:
return convertUse(c, v)
case *ast.SetStmt:
return convertSet(c, v)
case *ast.SetCharsetStmt:
return convertSetCharset(c, v)
case *ast.SetPwdStmt:
return convertSetPwd(c, v)
case *ast.CreateUserStmt:
return convertCreateUser(c, v)
case *ast.DoStmt:
return convertDo(c, v)
case *ast.GrantStmt:
return convertGrant(c, v)
}
return nil, nil
}
type paramMarkers []*ast.ParamMarkerExpr
func (p paramMarkers) Len() int {
return len(p)
}
func (p paramMarkers) Less(i, j int) bool {
return p[i].Offset < p[j].Offset
}
func (p paramMarkers) Swap(i, j int) {
p[i], p[j] = p[j], p[i]
}
// ParamMarkers returns parameter markers for prepared statement.
func (com *Compiler) ParamMarkers() []*expression.ParamMarker {
c := com.converter
sort.Sort(c.paramMarkers)
oldMarkers := make([]*expression.ParamMarker, len(c.paramMarkers))
for i, val := range c.paramMarkers {
oldMarkers[i] = c.exprMap[val].(*expression.ParamMarker)
}
return oldMarkers
}

425
optimizer/convert_expr.go Normal file
View File

@ -0,0 +1,425 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/subquery"
"github.com/pingcap/tidb/model"
"strings"
)
func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression.Expression, error) {
expr.Accept(converter)
if converter.err != nil {
return nil, errors.Trace(converter.err)
}
return converter.exprMap[expr], nil
}
// expressionConverter converts ast expression to
// old expression for transition state.
type expressionConverter struct {
exprMap map[ast.Node]expression.Expression
paramMarkers paramMarkers
err error
}
func newExpressionConverter() *expressionConverter {
return &expressionConverter{
exprMap: map[ast.Node]expression.Expression{},
}
}
// Enter implements ast.Visitor interface.
func (c *expressionConverter) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
return in, false
}
// Leave implements ast.Visitor interface.
func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) {
switch v := in.(type) {
case *ast.ValueExpr:
c.value(v)
case *ast.BetweenExpr:
c.between(v)
case *ast.BinaryOperationExpr:
c.binaryOperation(v)
case *ast.WhenClause:
c.whenClause(v)
case *ast.CaseExpr:
c.caseExpr(v)
case *ast.SubqueryExpr:
c.subquery(v)
case *ast.CompareSubqueryExpr:
c.compareSubquery(v)
case *ast.ColumnNameExpr:
c.columnNameExpr(v)
case *ast.DefaultExpr:
c.defaultExpr(v)
case *ast.IdentifierExpr:
c.identifier(v)
case *ast.ExistsSubqueryExpr:
c.existsSubquery(v)
case *ast.PatternInExpr:
c.patternIn(v)
case *ast.IsNullExpr:
c.isNull(v)
case *ast.IsTruthExpr:
c.isTruth(v)
case *ast.PatternLikeExpr:
c.patternLike(v)
case *ast.ParamMarkerExpr:
c.paramMarker(v)
case *ast.ParenthesesExpr:
c.parentheses(v)
case *ast.PositionExpr:
c.position(v)
case *ast.PatternRegexpExpr:
c.patternRegexp(v)
case *ast.RowExpr:
c.row(v)
case *ast.UnaryOperationExpr:
c.unaryOperation(v)
case *ast.ValuesExpr:
c.values(v)
case *ast.VariableExpr:
c.variable(v)
case *ast.FuncCallExpr:
c.funcCall(v)
case *ast.FuncExtractExpr:
c.funcExtract(v)
case *ast.FuncConvertExpr:
c.funcConvert(v)
case *ast.FuncCastExpr:
c.funcCast(v)
case *ast.FuncSubstringExpr:
c.funcSubstring(v)
case *ast.FuncLocateExpr:
c.funcLocate(v)
case *ast.FuncTrimExpr:
c.funcTrim(v)
case *ast.FuncDateArithExpr:
c.funcDateArith(v)
case *ast.AggregateFuncExpr:
c.aggregateFunc(v)
}
return in, c.err == nil
}
func (c *expressionConverter) value(v *ast.ValueExpr) {
c.exprMap[v] = expression.Value{Val: v.GetValue()}
}
func (c *expressionConverter) between(v *ast.BetweenExpr) {
oldExpr := c.exprMap[v.Expr]
oldLo := c.exprMap[v.Left]
oldHi := c.exprMap[v.Right]
oldBetween, err := expression.NewBetween(oldExpr, oldLo, oldHi, v.Not)
if err != nil {
c.err = err
return
}
c.exprMap[v] = oldBetween
}
func (c *expressionConverter) binaryOperation(v *ast.BinaryOperationExpr) {
oldeLeft := c.exprMap[v.L]
oldRight := c.exprMap[v.R]
oldBinop := expression.NewBinaryOperation(v.Op, oldeLeft, oldRight)
c.exprMap[v] = oldBinop
}
func (c *expressionConverter) whenClause(v *ast.WhenClause) {
oldExpr := c.exprMap[v.Expr]
oldResult := c.exprMap[v.Result]
oldWhenClause := &expression.WhenClause{Expr: oldExpr, Result: oldResult}
c.exprMap[v] = oldWhenClause
}
func (c *expressionConverter) caseExpr(v *ast.CaseExpr) {
oldValue := c.exprMap[v.Value]
oldWhenClauses := make([]*expression.WhenClause, len(v.WhenClauses))
for i, val := range v.WhenClauses {
oldWhenClauses[i] = c.exprMap[val].(*expression.WhenClause)
}
oldElse := c.exprMap[v.ElseClause]
oldCaseExpr := &expression.FunctionCase{
Value: oldValue,
WhenClauses: oldWhenClauses,
ElseClause: oldElse,
}
c.exprMap[v] = oldCaseExpr
}
func (c *expressionConverter) subquery(v *ast.SubqueryExpr) {
oldSubquery := &subquery.SubQuery{}
switch x := v.Query.(type) {
case *ast.SelectStmt:
oldSelect, err := convertSelect(c, x)
if err != nil {
c.err = err
return
}
oldSubquery.Stmt = oldSelect
case *ast.UnionStmt:
oldUnion, err := convertUnion(c, x)
if err != nil {
c.err = err
return
}
oldSubquery.Stmt = oldUnion
}
c.exprMap[v] = oldSubquery
}
func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) {
expr := c.exprMap[v.L]
subquery := c.exprMap[v.R]
oldCmpSubquery := expression.NewCompareSubQuery(v.Op, expr, subquery.(expression.SubQuery), v.All)
c.exprMap[v] = oldCmpSubquery
}
func joinColumnName(columnName *ast.ColumnName) string {
var originStrs []string
if columnName.Schema.O != "" {
originStrs = append(originStrs, columnName.Schema.O)
}
if columnName.Table.O != "" {
originStrs = append(originStrs, columnName.Table.O)
}
originStrs = append(originStrs, columnName.Name.O)
return strings.Join(originStrs, ".")
}
func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) {
ident := &expression.Ident{}
ident.CIStr = model.NewCIStr(joinColumnName(v.Name))
c.exprMap[v] = ident
}
func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) {
oldDefault := &expression.Default{}
if v.Name != nil {
oldDefault.Name = joinColumnName(v.Name)
}
c.exprMap[v] = oldDefault
}
func (c *expressionConverter) identifier(v *ast.IdentifierExpr) {
oldIdent := &expression.Ident{}
oldIdent.CIStr = v.Name
c.exprMap[v] = oldIdent
}
func (c *expressionConverter) existsSubquery(v *ast.ExistsSubqueryExpr) {
subquery := c.exprMap[v.Sel].(expression.SubQuery)
c.exprMap[v] = &expression.ExistsSubQuery{Sel: subquery}
}
func (c *expressionConverter) patternIn(v *ast.PatternInExpr) {
oldPatternIn := &expression.PatternIn{Not: v.Not}
if v.Sel != nil {
oldPatternIn.Sel = c.exprMap[v.Sel].(expression.SubQuery)
}
oldPatternIn.Expr = c.exprMap[v.Expr]
if v.List != nil {
oldPatternIn.List = make([]expression.Expression, len(v.List))
for i, v := range v.List {
oldPatternIn.List[i] = c.exprMap[v]
}
}
c.exprMap[v] = oldPatternIn
}
func (c *expressionConverter) isNull(v *ast.IsNullExpr) {
oldIsNull := &expression.IsNull{Not: v.Not}
oldIsNull.Expr = c.exprMap[v.Expr]
c.exprMap[v] = oldIsNull
}
func (c *expressionConverter) isTruth(v *ast.IsTruthExpr) {
oldIsTruth := &expression.IsTruth{Not: v.Not, True: v.True}
oldIsTruth.Expr = c.exprMap[v.Expr]
c.exprMap[v] = oldIsTruth
}
func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) {
oldPatternLike := &expression.PatternLike{
Not: v.Not,
Escape: v.Escape,
Expr: c.exprMap[v.Expr],
Pattern: c.exprMap[v.Pattern],
}
c.exprMap[v] = oldPatternLike
}
func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) {
if c.exprMap[v] == nil {
c.exprMap[v] = &expression.ParamMarker{}
c.paramMarkers = append(c.paramMarkers, v)
}
}
func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) {
oldExpr := c.exprMap[v.Expr]
c.exprMap[v] = &expression.PExpr{Expr: oldExpr}
}
func (c *expressionConverter) position(v *ast.PositionExpr) {
c.exprMap[v] = &expression.Position{N: v.N, Name: v.Name}
}
func (c *expressionConverter) patternRegexp(v *ast.PatternRegexpExpr) {
oldPatternRegexp := &expression.PatternRegexp{
Not: v.Not,
Expr: c.exprMap[v.Expr],
Pattern: c.exprMap[v.Pattern],
}
c.exprMap[v] = oldPatternRegexp
}
func (c *expressionConverter) row(v *ast.RowExpr) {
oldRow := &expression.Row{}
oldRow.Values = make([]expression.Expression, len(v.Values))
for i, val := range v.Values {
oldRow.Values[i] = c.exprMap[val]
}
c.exprMap[v] = oldRow
}
func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) {
oldUnary := &expression.UnaryOperation{
Op: v.Op,
V: c.exprMap[v.V],
}
c.exprMap[v] = oldUnary
}
func (c *expressionConverter) values(v *ast.ValuesExpr) {
nameStr := joinColumnName(v.Column)
c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)}
}
func (c *expressionConverter) variable(v *ast.VariableExpr) {
c.exprMap[v] = &expression.Variable{
IsGlobal: v.IsGlobal,
IsSystem: v.IsSystem,
Name: v.Name,
}
}
func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) {
oldCall := &expression.Call{
F: v.FnName,
}
oldCall.Args = make([]expression.Expression, len(v.Args))
for i, val := range v.Args {
oldCall.Args[i] = c.exprMap[val]
}
c.exprMap[v] = oldCall
}
func (c *expressionConverter) funcExtract(v *ast.FuncExtractExpr) {
oldExtract := &expression.Extract{Unit: v.Unit}
oldExtract.Date = c.exprMap[v.Date]
c.exprMap[v] = oldExtract
}
func (c *expressionConverter) funcConvert(v *ast.FuncConvertExpr) {
c.exprMap[v] = &expression.FunctionConvert{
Expr: c.exprMap[v.Expr],
Charset: v.Charset,
}
}
func (c *expressionConverter) funcCast(v *ast.FuncCastExpr) {
oldCast := &expression.FunctionCast{
Expr: c.exprMap[v.Expr],
Tp: v.Tp,
}
switch v.FunctionType {
case ast.CastBinaryOperator:
oldCast.FunctionType = expression.BinaryOperator
case ast.CastConvertFunction:
oldCast.FunctionType = expression.ConvertFunction
case ast.CastFunction:
oldCast.FunctionType = expression.CastFunction
}
c.exprMap[v] = oldCast
}
func (c *expressionConverter) funcSubstring(v *ast.FuncSubstringExpr) {
oldSubstring := &expression.FunctionSubstring{
Len: c.exprMap[v.Len],
Pos: c.exprMap[v.Pos],
StrExpr: c.exprMap[v.StrExpr],
}
c.exprMap[v] = oldSubstring
}
func (c *expressionConverter) funcLocate(v *ast.FuncLocateExpr) {
oldLocate := &expression.FunctionLocate{
Pos: c.exprMap[v.Pos],
Str: c.exprMap[v.Str],
SubStr: c.exprMap[v.SubStr],
}
c.exprMap[v] = oldLocate
}
func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) {
oldTrim := &expression.FunctionTrim{
Str: c.exprMap[v.Str],
RemStr: c.exprMap[v.RemStr],
}
switch v.Direction {
case ast.TrimBoth:
oldTrim.Direction = expression.TrimBoth
case ast.TrimBothDefault:
oldTrim.Direction = expression.TrimBothDefault
case ast.TrimLeading:
oldTrim.Direction = expression.TrimLeading
case ast.TrimTrailing:
oldTrim.Direction = expression.TrimTrailing
}
c.exprMap[v] = oldTrim
}
func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) {
oldDateArith := &expression.DateArith{
Unit: v.Unit,
Date: c.exprMap[v.Date],
Interval: c.exprMap[v.Interval],
}
switch v.Op {
case ast.DateAdd:
oldDateArith.Op = expression.DateAdd
case ast.DateSub:
oldDateArith.Op = expression.DateSub
}
c.exprMap[v] = oldDateArith
}
func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) {
oldAggregate := &expression.Call{
F: v.F,
Distinct: v.Distinct,
}
for _, val := range v.Args {
oldAggregate.Args = append(oldAggregate.Args, c.exprMap[val])
}
c.exprMap[v] = oldAggregate
}

1024
optimizer/convert_stmt.go Normal file

File diff suppressed because it is too large Load Diff

317
optimizer/evaluator.go Normal file
View File

@ -0,0 +1,317 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/plan"
)
// Evaluator is a ast visitor that evaluates an expression.
type Evaluator struct {
// columnMap is the map from ColumnName to the position of the rowStack.
// It is used to find the value of the column.
columnMap map[*ast.ColumnName]position
// rowStack is the current row values while scanning.
// It should be updated after scaned a new row.
rowStack []*plan.Row
// the map from AggregateFuncExpr to aggregator index.
aggregateMap map[*ast.AggregateFuncExpr]int
// aggregators for the current row
// only outer query aggregate functions are handled.
aggregators []Aggregator
// when aggregation phase is done, the input is
aggregateDone bool
err error
}
type position struct {
stackOffset int
fieldList bool
columnOffset int
}
// Enter implements ast.Visitor interface.
func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
return
}
// Leave implements ast.Visitor interface.
func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) {
switch v := in.(type) {
case *ast.ValueExpr:
ok = true
case *ast.BetweenExpr:
ok = e.between(v)
case *ast.BinaryOperationExpr:
ok = e.binaryOperation(v)
case *ast.WhenClause:
ok = e.whenClause(v)
case *ast.CaseExpr:
ok = e.caseExpr(v)
case *ast.SubqueryExpr:
ok = e.subquery(v)
case *ast.CompareSubqueryExpr:
ok = e.compareSubquery(v)
case *ast.ColumnNameExpr:
ok = e.columnName(v)
case *ast.DefaultExpr:
ok = e.defaultExpr(v)
case *ast.IdentifierExpr:
ok = e.identifier(v)
case *ast.ExistsSubqueryExpr:
ok = e.existsSubquery(v)
case *ast.PatternInExpr:
ok = e.patternIn(v)
case *ast.IsNullExpr:
ok = e.isNull(v)
case *ast.IsTruthExpr:
ok = e.isTruth(v)
case *ast.PatternLikeExpr:
ok = e.patternLike(v)
case *ast.ParamMarkerExpr:
ok = e.paramMarker(v)
case *ast.ParenthesesExpr:
ok = e.parentheses(v)
case *ast.PositionExpr:
ok = e.position(v)
case *ast.PatternRegexpExpr:
ok = e.patternRegexp(v)
case *ast.RowExpr:
ok = e.row(v)
case *ast.UnaryOperationExpr:
ok = e.unaryOperation(v)
case *ast.ValuesExpr:
ok = e.values(v)
case *ast.VariableExpr:
ok = e.variable(v)
case *ast.FuncCallExpr:
ok = e.funcCall(v)
case *ast.FuncExtractExpr:
ok = e.funcExtract(v)
case *ast.FuncConvertExpr:
ok = e.funcConvert(v)
case *ast.FuncCastExpr:
ok = e.funcCast(v)
case *ast.FuncSubstringExpr:
ok = e.funcSubstring(v)
case *ast.FuncLocateExpr:
ok = e.funcLocate(v)
case *ast.FuncTrimExpr:
ok = e.funcTrim(v)
case *ast.AggregateFuncExpr:
ok = e.aggregateFunc(v)
}
out = in
return
}
func checkAllOneColumn(exprs ...ast.ExprNode) bool {
for _, expr := range exprs {
switch v := expr.(type) {
case *ast.RowExpr:
return false
case *ast.SubqueryExpr:
if len(v.Query.GetResultFields()) != 1 {
return false
}
}
}
return true
}
func (e *Evaluator) between(v *ast.BetweenExpr) bool {
if !checkAllOneColumn(v.Expr, v.Left, v.Right) {
e.err = errors.Errorf("Operand should contain 1 column(s)")
return false
}
var l, r ast.ExprNode
op := opcode.AndAnd
if v.Not {
// v < lv || v > rv
op = opcode.OrOr
l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left}
r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right}
} else {
// v >= lv && v <= rv
l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left}
r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right}
}
ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r}
ret.Accept(e)
return e.err == nil
}
func columnCount(e ast.ExprNode) (int, error) {
switch x := e.(type) {
case *ast.RowExpr:
n := len(x.Values)
if n <= 1 {
return 0, errors.Errorf("Operand should contain >= 2 columns for Row")
}
return n, nil
case *ast.SubqueryExpr:
return len(x.Query.GetResultFields()), nil
default:
return 1, nil
}
}
func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error {
l, err := columnCount(e)
if err != nil {
return errors.Trace(err)
}
var n int
for _, arg := range args {
n, err = columnCount(arg)
if err != nil {
return errors.Trace(err)
}
if n != l {
return errors.Errorf("Operand should contain %d column(s)", l)
}
}
return nil
}
func (e *Evaluator) whenClause(v *ast.WhenClause) bool {
return true
}
func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool {
return true
}
func (e *Evaluator) subquery(v *ast.SubqueryExpr) bool {
return true
}
func (e *Evaluator) compareSubquery(v *ast.CompareSubqueryExpr) bool {
return true
}
func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool {
return true
}
func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool {
return true
}
func (e *Evaluator) identifier(v *ast.IdentifierExpr) bool {
return true
}
func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
return true
}
func (e *Evaluator) patternIn(v *ast.PatternInExpr) bool {
return true
}
func (e *Evaluator) isNull(v *ast.IsNullExpr) bool {
return true
}
func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
return true
}
func (e *Evaluator) patternLike(v *ast.PatternLikeExpr) bool {
return true
}
func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool {
return true
}
func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool {
return true
}
func (e *Evaluator) position(v *ast.PositionExpr) bool {
return true
}
func (e *Evaluator) patternRegexp(v *ast.PatternRegexpExpr) bool {
return true
}
func (e *Evaluator) row(v *ast.RowExpr) bool {
return true
}
func (e *Evaluator) unaryOperation(v *ast.UnaryOperationExpr) bool {
return true
}
func (e *Evaluator) values(v *ast.ValuesExpr) bool {
return true
}
func (e *Evaluator) variable(v *ast.VariableExpr) bool {
return true
}
func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool {
return true
}
func (e *Evaluator) funcExtract(v *ast.FuncExtractExpr) bool {
return true
}
func (e *Evaluator) funcConvert(v *ast.FuncConvertExpr) bool {
return true
}
func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool {
return true
}
func (e *Evaluator) funcSubstring(v *ast.FuncSubstringExpr) bool {
return true
}
func (e *Evaluator) funcLocate(v *ast.FuncLocateExpr) bool {
return true
}
func (e *Evaluator) funcTrim(v *ast.FuncTrimExpr) bool {
return true
}
func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
idx := e.aggregateMap[v]
aggr := e.aggregators[idx]
if e.aggregateDone {
v.SetValue(aggr.Output())
return true
}
// TODO: currently only single argument aggregate functions are supported.
e.err = aggr.Input(v.Args[0].GetValue())
return e.err == nil
}

View File

@ -0,0 +1,527 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"math"
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
const (
zeroI64 int64 = 0
oneI64 int64 = 1
)
func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool {
// all operands must have same column.
if e.err = hasSameColumnCount(o.L, o.R); e.err != nil {
return false
}
// row constructor only supports comparison operation.
switch o.Op {
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
default:
if !checkAllOneColumn(o.L) {
e.err = errors.Errorf("Operand should contain 1 column(s)")
return false
}
}
leftVal, err := types.Convert(o.L.GetValue(), o.GetType())
if err != nil {
e.err = err
return false
}
rightVal, err := types.Convert(o.R.GetValue(), o.GetType())
if err != nil {
e.err = err
return false
}
if leftVal == nil || rightVal == nil {
o.SetValue(nil)
return true
}
switch o.Op {
case opcode.AndAnd, opcode.OrOr, opcode.LogicXor:
return e.handleLogicOperation(o)
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
return e.handleComparisonOp(o)
case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor:
// TODO: MySQL doesn't support and not, we should remove it later.
return e.handleBitOp(o)
case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv:
return e.handleArithmeticOp(o)
default:
panic("should never happen")
}
}
func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool {
leftVal, err := types.Convert(o.L.GetValue(), o.GetType())
if err != nil {
e.err = err
return false
}
rightVal, err := types.Convert(o.R.GetValue(), o.GetType())
if err != nil {
e.err = err
return false
}
if leftVal == nil || rightVal == nil {
o.SetValue(nil)
return true
}
var boolVal bool
switch o.Op {
case opcode.AndAnd:
boolVal = leftVal != zeroI64 && rightVal != zeroI64
case opcode.OrOr:
boolVal = leftVal != zeroI64 || rightVal != zeroI64
case opcode.LogicXor:
boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64)
default:
panic("should never happen")
}
if boolVal {
o.SetValue(oneI64)
} else {
o.SetValue(zeroI64)
}
return true
}
func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool {
a, b := types.Coerce(o.L.GetValue(), o.R.GetValue())
if types.IsNil(a) || types.IsNil(b) {
// for <=>, if a and b are both nil, return true.
// if a or b is nil, return false.
if o.Op == opcode.NullEQ {
if types.IsNil(a) || types.IsNil(b) {
o.SetValue(oneI64)
} else {
o.SetValue(zeroI64)
}
} else {
o.SetValue(nil)
}
return true
}
n, err := types.Compare(a, b)
if err != nil {
e.err = err
return false
}
r, err := getCompResult(o.Op, n)
if err != nil {
e.err = err
return false
}
if r {
o.SetValue(oneI64)
} else {
o.SetValue(zeroI64)
}
return true
}
func getCompResult(op opcode.Op, value int) (bool, error) {
switch op {
case opcode.LT:
return value < 0, nil
case opcode.LE:
return value <= 0, nil
case opcode.GE:
return value >= 0, nil
case opcode.GT:
return value > 0, nil
case opcode.EQ:
return value == 0, nil
case opcode.NE:
return value != 0, nil
case opcode.NullEQ:
return value == 0, nil
default:
return false, errors.Errorf("invalid op %v in comparision operation", op)
}
}
func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool {
a, b := types.Coerce(o.L.GetValue(), o.R.GetValue())
if types.IsNil(a) || types.IsNil(b) {
o.SetValue(nil)
return true
}
x, err := types.ToInt64(a)
if err != nil {
e.err = err
return false
}
y, err := types.ToInt64(b)
if err != nil {
e.err = err
return false
}
// use a int64 for bit operator, return uint64
switch o.Op {
case opcode.And:
o.SetValue(uint64(x & y))
case opcode.Or:
o.SetValue(uint64(x | y))
case opcode.Xor:
o.SetValue(uint64(x ^ y))
case opcode.RightShift:
o.SetValue(uint64(x) >> uint64(y))
case opcode.LeftShift:
o.SetValue(uint64(x) << uint64(y))
default:
e.err = errors.Errorf("invalid op %v in bit operation", o.Op)
return false
}
return true
}
func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool {
a, err := coerceArithmetic(o.L.GetValue())
if err != nil {
e.err = err
return false
}
b, err := coerceArithmetic(o.R.GetValue())
if err != nil {
e.err = err
return false
}
a, b = types.Coerce(a, b)
if a == nil || b == nil {
// TODO: for <=>, if a and b are both nil, return true
o.SetValue(nil)
return true
}
// TODO: support logic division DIV
var result interface{}
switch o.Op {
case opcode.Plus:
result, e.err = computePlus(a, b)
case opcode.Minus:
result, e.err = computeMinus(a, b)
case opcode.Mul:
result, e.err = computeMul(a, b)
case opcode.Div:
result, e.err = computeDiv(a, b)
case opcode.Mod:
result, e.err = computeMod(a, b)
case opcode.IntDiv:
result, e.err = computeIntDiv(a, b)
default:
e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op)
return false
}
o.SetValue(result)
return e.err == nil
}
func computePlus(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
return types.AddInt64(x, y)
case uint64:
return types.AddInteger(y, x)
}
case uint64:
switch y := b.(type) {
case int64:
return types.AddInteger(x, y)
case uint64:
return types.AddUint64(x, y)
}
case float64:
switch y := b.(type) {
case float64:
return x + y, nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
return x.Add(y), nil
}
}
return types.InvOp2(a, b, opcode.Plus)
}
func computeMinus(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
return types.SubInt64(x, y)
case uint64:
return types.SubIntWithUint(x, y)
}
case uint64:
switch y := b.(type) {
case int64:
return types.SubUintWithInt(x, y)
case uint64:
return types.SubUint64(x, y)
}
case float64:
switch y := b.(type) {
case float64:
return x - y, nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
return x.Sub(y), nil
}
}
return types.InvOp2(a, b, opcode.Minus)
}
func computeMul(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
return types.MulInt64(x, y)
case uint64:
return types.MulInteger(y, x)
}
case uint64:
switch y := b.(type) {
case int64:
return types.MulInteger(x, y)
case uint64:
return types.MulUint64(x, y)
}
case float64:
switch y := b.(type) {
case float64:
return x * y, nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
return x.Mul(y), nil
}
}
return types.InvOp2(a, b, opcode.Mul)
}
func computeDiv(a, b interface{}) (interface{}, error) {
// MySQL support integer divison Div and division operator /
// we use opcode.Div for division operator and will use another for integer division later.
// for division operator, we will use float64 for calculation.
switch x := a.(type) {
case float64:
y, err := types.ToFloat64(b)
if err != nil {
return nil, err
}
if y == 0 {
return nil, nil
}
return x / y, nil
default:
// the scale of the result is the scale of the first operand plus
// the value of the div_precision_increment system variable (which is 4 by default)
// we will use 4 here
xa, err := types.ToDecimal(a)
if err != nil {
return nil, err
}
xb, err := types.ToDecimal(b)
if err != nil {
return nil, err
}
if f, _ := xb.Float64(); f == 0 {
// division by zero return null
return nil, nil
}
return xa.Div(xb), nil
}
}
func computeMod(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
}
return x % y, nil
case uint64:
if y == 0 {
return nil, nil
} else if x < 0 {
// first is int64, return int64.
return -int64(uint64(-x) % y), nil
}
return int64(uint64(x) % y), nil
}
case uint64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
} else if y < 0 {
// first is uint64, return uint64.
return uint64(x % uint64(-y)), nil
}
return x % uint64(y), nil
case uint64:
if y == 0 {
return nil, nil
}
return x % y, nil
}
case float64:
switch y := b.(type) {
case float64:
if y == 0 {
return nil, nil
}
return math.Mod(x, y), nil
}
case mysql.Decimal:
switch y := b.(type) {
case mysql.Decimal:
xf, _ := x.Float64()
yf, _ := y.Float64()
if yf == 0 {
return nil, nil
}
return math.Mod(xf, yf), nil
}
}
return types.InvOp2(a, b, opcode.Mod)
}
func computeIntDiv(a, b interface{}) (interface{}, error) {
switch x := a.(type) {
case int64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
}
return types.DivInt64(x, y)
case uint64:
if y == 0 {
return nil, nil
}
return types.DivIntWithUint(x, y)
}
case uint64:
switch y := b.(type) {
case int64:
if y == 0 {
return nil, nil
}
return types.DivUintWithInt(x, y)
case uint64:
if y == 0 {
return nil, nil
}
return x / y, nil
}
}
// if any is none integer, use decimal to calculate
x, err := types.ToDecimal(a)
if err != nil {
return nil, err
}
y, err := types.ToDecimal(b)
if err != nil {
return nil, err
}
if f, _ := y.Float64(); f == 0 {
return nil, nil
}
return x.Div(y).IntPart(), nil
}
func coerceArithmetic(a interface{}) (interface{}, error) {
switch x := a.(type) {
case string:
// MySQL will convert string to float for arithmetic operation
f, err := types.StrToFloat(x)
if err != nil {
return nil, err
}
return f, err
case mysql.Time:
// if time has no precision, return int64
v := x.ToNumber()
if x.Fsp == 0 {
return v.IntPart(), nil
}
return v, nil
case mysql.Duration:
// if duration has no precision, return int64
v := x.ToNumber()
if x.Fsp == 0 {
return v.IntPart(), nil
}
return v, nil
case []byte:
// []byte is the same as string, converted to float for arithmetic operator.
f, err := types.StrToFloat(string(x))
if err != nil {
return nil, err
}
return f, err
case mysql.Hex:
return x.ToNumber(), nil
case mysql.Bit:
return x.ToNumber(), nil
case mysql.Enum:
return x.ToNumber(), nil
case mysql.Set:
return x.ToNumber(), nil
default:
return x, nil
}
}

441
optimizer/infobinder.go Normal file
View File

@ -0,0 +1,441 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/infoschema"
"github.com/pingcap/tidb/model"
)
// InfoBinder binds schema information for table name and column name and set result fields
// for ResetSetNode.
// We need to know which table a table name refers to, which column a column name refers to.
//
// In general, a reference can only refer to information that are available for it.
// So children elements are visited in the order that previous elements make information
// available for following elements.
//
// During visiting, information are collected and stored in binderContext.
// When we enter a sub query, a new binderContext is pushed to the contextStack, so sub query
// information can overwrite outer query information. When we look up for a column reference,
// we look up from top to bottom in the contextStack.
type InfoBinder struct {
Info infoschema.InfoSchema
DefaultSchema model.CIStr
Err error
contextStack []*binderContext
}
// binderContext stores information that table name and column name
// can be bind to.
type binderContext struct {
/* For Select Statement. */
// table map to lookup and check table name conflict.
tableMap map[string]int
// tableSources collected in from clause.
tables []*ast.TableSource
// result fields collected in select field list.
fieldList []*ast.ResultField
// result fields collected in group by clause.
groupBy []*ast.ResultField
// The join node stack is used by on condition to find out
// available tables to reference. On condition can only
// refer to tables involved in current join.
joinNodeStack []*ast.Join
// When visiting TableRefs, tables in this context are not available
// because it is being collected.
inTableRefs bool
// When visiting on conditon only tables in current join node are available.
inOnCondition bool
// When visiting field list, fieldList in this context are not available.
inFieldList bool
// When visiting group by, groupBy fields are not available.
inGroupBy bool
// When visiting having, only fieldList and groupBy fields are available.
inHaving bool
}
// currentContext gets the current binder context.
func (sb *InfoBinder) currentContext() *binderContext {
stackLen := len(sb.contextStack)
if stackLen == 0 {
return nil
}
return sb.contextStack[stackLen-1]
}
// pushContext is called when we enter a statement.
func (sb *InfoBinder) pushContext() {
sb.contextStack = append(sb.contextStack, &binderContext{
tableMap: map[string]int{},
})
}
// popContext is called when we leave a statement.
func (sb *InfoBinder) popContext() {
sb.contextStack = sb.contextStack[:len(sb.contextStack)-1]
}
// pushJoin is called when we enter a join node.
func (sb *InfoBinder) pushJoin(j *ast.Join) {
ctx := sb.currentContext()
ctx.joinNodeStack = append(ctx.joinNodeStack, j)
}
// popJoin is called when we leave a join node.
func (sb *InfoBinder) popJoin() {
ctx := sb.currentContext()
ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1]
}
// Enter implements ast.Visitor interface.
func (sb *InfoBinder) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
switch v := inNode.(type) {
case *ast.SelectStmt:
sb.pushContext()
case *ast.TableRefsClause:
sb.currentContext().inTableRefs = true
case *ast.Join:
sb.pushJoin(v)
case *ast.OnCondition:
sb.currentContext().inOnCondition = true
case *ast.FieldList:
sb.currentContext().inFieldList = true
case *ast.GroupByClause:
sb.currentContext().inGroupBy = true
case *ast.HavingClause:
sb.currentContext().inHaving = true
case *ast.InsertStmt:
sb.pushContext()
case *ast.DeleteStmt:
sb.pushContext()
case *ast.UpdateStmt:
sb.pushContext()
}
return inNode, false
}
// Leave implements ast.Visitor interface.
func (sb *InfoBinder) Leave(inNode ast.Node) (node ast.Node, ok bool) {
switch v := inNode.(type) {
case *ast.TableName:
sb.handleTableName(v)
case *ast.ColumnName:
sb.handleColumnName(v)
case *ast.TableSource:
sb.handleTableSource(v)
case *ast.OnCondition:
sb.currentContext().inOnCondition = false
case *ast.Join:
sb.handleJoin(v)
sb.popJoin()
case *ast.TableRefsClause:
sb.currentContext().inTableRefs = false
case *ast.FieldList:
sb.handleFieldList(v)
sb.currentContext().inFieldList = false
case *ast.GroupByClause:
sb.currentContext().inGroupBy = false
case *ast.HavingClause:
sb.currentContext().inHaving = false
case *ast.SelectStmt:
v.SetResultFields(sb.currentContext().fieldList)
sb.popContext()
case *ast.InsertStmt:
sb.popContext()
case *ast.DeleteStmt:
sb.popContext()
case *ast.UpdateStmt:
sb.popContext()
}
return inNode, sb.Err == nil
}
// handleTableName looks up and bind the schema information for table name
// and set result fields for table name.
func (sb *InfoBinder) handleTableName(tn *ast.TableName) {
if tn.Schema.L == "" {
tn.Schema = sb.DefaultSchema
}
table, err := sb.Info.TableByName(tn.Schema, tn.Name)
if err != nil {
sb.Err = err
return
}
tn.TableInfo = table.Meta()
dbInfo, _ := sb.Info.SchemaByName(tn.Schema)
tn.DBInfo = dbInfo
rfs := make([]*ast.ResultField, len(tn.TableInfo.Columns))
for i, v := range tn.TableInfo.Columns {
rfs[i] = &ast.ResultField{
Column: v,
Table: tn.TableInfo,
DBName: tn.Schema,
}
}
tn.SetResultFields(rfs)
return
}
// handleTableSources checks name duplication
// and puts the table source in current binderContext.
func (sb *InfoBinder) handleTableSource(ts *ast.TableSource) {
for _, v := range ts.GetResultFields() {
v.TableAsName = ts.AsName
}
var name string
if ts.AsName.L != "" {
name = ts.AsName.L
} else {
tableName := ts.Source.(*ast.TableName)
name = sb.tableUniqueName(tableName.Schema, tableName.Name)
}
ctx := sb.currentContext()
if _, ok := ctx.tableMap[name]; ok {
sb.Err = errors.Errorf("duplicated table/alias name %s", name)
return
}
ctx.tableMap[name] = len(ctx.tables)
ctx.tables = append(ctx.tables, ts)
return
}
// handleJoin sets result fields for join.
func (sb *InfoBinder) handleJoin(j *ast.Join) {
if j.Right == nil {
j.SetResultFields(j.Left.GetResultFields())
return
}
leftLen := len(j.Left.GetResultFields())
rightLen := len(j.Right.GetResultFields())
rfs := make([]*ast.ResultField, leftLen+rightLen)
copy(rfs, j.Left.GetResultFields())
copy(rfs[leftLen:], j.Right.GetResultFields())
j.SetResultFields(rfs)
}
// handleColumnName looks up and binds schema information to
// the column name.
func (sb *InfoBinder) handleColumnName(cn *ast.ColumnName) {
ctx := sb.currentContext()
if ctx.inOnCondition {
// In on condition, only tables within current join is available.
sb.bindColumnNameInOnCondition(cn)
return
}
// Try to bind the column name form top to bottom in the context stack.
for i := len(sb.contextStack) - 1; i >= 0; i-- {
if sb.bindColumnNameInContext(sb.contextStack[i], cn) {
// Column is already bound or encountered an error.
return
}
}
sb.Err = errors.Errorf("Unknown column %s", cn.Name.L)
}
// bindColumnNameInContext looks up and binds schema information for a column with the ctx.
func (sb *InfoBinder) bindColumnNameInContext(ctx *binderContext, cn *ast.ColumnName) (done bool) {
if cn.Table.L == "" {
// If qualified table name is not specified in column name, the column name may be ambiguous,
// We need to iterate over all tables and
}
if ctx.inTableRefs {
// In TableRefsClause, column reference only in join on condition which is handled before.
return false
}
if ctx.inFieldList {
// only bind column using tables.
return sb.bindColumnInTableSources(cn, ctx.tables)
}
if ctx.inGroupBy {
// field list first, then tables.
if sb.bindColumnInResultFields(cn, ctx.fieldList) {
return true
}
return sb.bindColumnInTableSources(cn, ctx.tables)
}
// column name in other places can be looked up in the same order.
if sb.bindColumnInResultFields(cn, ctx.groupBy) {
return true
}
if sb.bindColumnInResultFields(cn, ctx.fieldList) {
return true
}
// tables is not available for having clause.
if !ctx.inHaving {
return sb.bindColumnInTableSources(cn, ctx.tables)
}
return false
}
// bindColumnNameInOnCondition looks up for column name in current join, and
// binds the schema information.
func (sb *InfoBinder) bindColumnNameInOnCondition(cn *ast.ColumnName) {
ctx := sb.currentContext()
join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1]
tableSources := appendTableSources(nil, join)
if !sb.bindColumnInTableSources(cn, tableSources) {
sb.Err = errors.Errorf("unkown column name %s", cn.Name.O)
}
}
func (sb *InfoBinder) bindColumnInTableSources(cn *ast.ColumnName, tableSources []*ast.TableSource) (done bool) {
var matchedResultField *ast.ResultField
if cn.Table.L != "" {
var matchedTable ast.ResultSetNode
for _, ts := range tableSources {
if cn.Table.L == ts.AsName.L {
// different table name.
matchedTable = ts
break
}
if tn, ok := ts.Source.(*ast.TableName); ok {
if cn.Table.L == tn.Name.L {
matchedTable = ts
}
}
}
if matchedTable != nil {
resultFields := matchedTable.GetResultFields()
for _, rf := range resultFields {
if rf.ColumnAsName.L == cn.Name.L || rf.Column.Name.L == cn.Name.L {
// bind column.
matchedResultField = rf
break
}
}
}
} else {
for _, ts := range tableSources {
rfs := ts.GetResultFields()
for _, rf := range rfs {
matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L
matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L
if matchAsName || matchColumnName {
if matchedResultField != nil {
sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O)
return true
}
matchedResultField = rf
}
}
}
}
if matchedResultField != nil {
// bind column.
cn.ColumnInfo = matchedResultField.Column
cn.TableInfo = matchedResultField.Table
return true
}
return false
}
func (sb *InfoBinder) bindColumnInResultFields(cn *ast.ColumnName, rfs []*ast.ResultField) bool {
var matchedResultField *ast.ResultField
for _, rf := range rfs {
matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L
matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L
if matchAsName || matchColumnName {
if matchedResultField != nil {
sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O)
return false
}
matchedResultField = rf
}
}
if matchedResultField != nil {
// bind column.
cn.ColumnInfo = matchedResultField.Column
cn.TableInfo = matchedResultField.Table
return true
}
return false
}
// handleFieldList expands wild card field and set fieldList in current context.
func (sb *InfoBinder) handleFieldList(fieldList *ast.FieldList) {
var resultFields []*ast.ResultField
for _, v := range fieldList.Fields {
resultFields = append(resultFields, sb.createResultFields(v)...)
}
sb.currentContext().fieldList = resultFields
}
// createResultFields creates result field list for a single select field.
func (sb *InfoBinder) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) {
ctx := sb.currentContext()
if field.WildCard != nil {
if len(ctx.tables) == 0 {
sb.Err = errors.Errorf("No table used.")
return
}
if field.WildCard.Table.L == "" {
for _, v := range ctx.tables {
rfs = append(rfs, v.GetResultFields()...)
}
} else {
name := sb.tableUniqueName(field.WildCard.Schema, field.WildCard.Table)
tableIdx, ok := ctx.tableMap[name]
if !ok {
sb.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O)
}
rfs = ctx.tables[tableIdx].GetResultFields()
}
return
}
// The column is visited before so it must has been bound already.
rf := &ast.ResultField{ColumnAsName: field.AsName}
switch v := field.Expr.(type) {
case *ast.ColumnNameExpr:
rf.Column = v.Name.ColumnInfo
rf.Table = v.Name.TableInfo
rf.DBName = v.Name.Schema
default:
if field.AsName.L == "" {
rf.ColumnAsName.L = field.Expr.Text()
rf.ColumnAsName.O = rf.ColumnAsName.L
}
}
rfs = append(rfs, rf)
return
}
func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) {
switch v := resultSetNode.(type) {
case *ast.TableSource:
out = append(in, v)
case *ast.Join:
out = appendTableSources(in, v.Left)
if v.Right != nil {
out = appendTableSources(out, v.Right)
}
}
return
}
func (sb *InfoBinder) tableUniqueName(schema, table model.CIStr) string {
if schema.L != "" && schema.L != sb.DefaultSchema.L {
return schema.L + "." + table.L
}
return table.L
}

View File

@ -0,0 +1,82 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer_test
import (
"testing"
. "github.com/pingcap/check"
"github.com/pingcap/tidb"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/optimizer"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/testkit"
)
func TestT(t *testing.T) {
TestingT(t)
}
var _ = Suite(&testInfoBinderSuite{})
type testInfoBinderSuite struct {
}
type binderVerifier struct {
c *C
}
func (bv *binderVerifier) Enter(node ast.Node) (ast.Node, bool) {
return node, false
}
func (bv *binderVerifier) Leave(in ast.Node) (out ast.Node, ok bool) {
switch v := in.(type) {
case *ast.ColumnName:
bv.c.Assert(v.ColumnInfo, NotNil)
case *ast.TableName:
bv.c.Assert(v.TableInfo, NotNil)
}
return in, true
}
func (ts *testInfoBinderSuite) TestInfoBinder(c *C) {
store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory)
c.Assert(err, IsNil)
defer store.Close()
testKit := testkit.NewTestKit(c, store)
testKit.MustExec("use test")
testKit.MustExec("create table t (c1 int, c2 int)")
domain := sessionctx.GetDomain(testKit.Se.(context.Context))
src := "SELECT c1 from t"
l := parser.NewLexer(src)
c.Assert(parser.YYParse(l), Equals, 0)
stmts := l.Stmts()
c.Assert(len(stmts), Equals, 1)
v := &optimizer.InfoBinder{
Info: domain.InfoSchema(),
DefaultSchema: model.NewCIStr("test"),
}
selectStmt := stmts[0].(*ast.SelectStmt)
selectStmt.Accept(v)
verifier := &binderVerifier{
c: c,
}
selectStmt.Accept(verifier)
}

View File

@ -13,20 +13,18 @@
package optimizer
import (
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/stmt"
)
import "github.com/pingcap/tidb/ast"
// Compile compiles a ast.Node into a executable statement.
func Compile(node ast.Node) (stmt.Statement, error) {
switch v := node.(type) {
case *ast.SetStmt:
return compileSet(v)
}
return nil, nil
// typeComputer is an ast Visitor that
// computes result type for ast.ExprNode.
type typeComputer struct {
err error
}
func compileSet(aset *ast.SetStmt) (stmt.Statement, error) {
return nil, nil
func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
return in, false
}
func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

30
optimizer/validator.go Normal file
View File

@ -0,0 +1,30 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package optimizer
import "github.com/pingcap/tidb/ast"
// validator is an ast.Visitor that validates
// ast parsed from parser.
type validator struct {
err error
}
func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
return in, false
}
func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
return in, true
}

File diff suppressed because it is too large Load Diff

View File

@ -18,9 +18,7 @@ import (
"testing"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/subquery"
"github.com/pingcap/tidb/stmt/stmts"
"github.com/pingcap/tidb/ast"
)
func TestT(t *testing.T) {
@ -32,23 +30,6 @@ var _ = Suite(&testParserSuite{})
type testParserSuite struct {
}
func (s *testParserSuite) TestOriginText(c *C) {
src := `SELECT stuff.id
FROM stuff
WHERE stuff.value >= ALL (SELECT stuff.value
FROM stuff)`
l := NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
node := l.Stmts()[0].(*stmts.SelectStmt)
sq := node.Where.Expr.(*expression.CompareSubQuery).R
c.Assert(sq, NotNil)
subsel := sq.(*subquery.SubQuery)
c.Assert(subsel.Stmt.OriginText(), Equals,
`SELECT stuff.value
FROM stuff`)
}
func (s *testParserSuite) TestSimple(c *C) {
// Testcase for unreserved keywords
unreservedKws := []string{
@ -70,15 +51,12 @@ func (s *testParserSuite) TestSimple(c *C) {
// Testcase for prepared statement
src := "SELECT id+?, id+? from t;"
l := NewLexer(src)
l.SetPrepare()
c.Assert(yyParse(l), Equals, 0)
c.Assert(len(l.ParamList), Equals, 2)
c.Assert(len(l.Stmts()), Equals, 1)
// Testcase for -- Comment and unary -- operator
src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;"
l = NewLexer(src)
l.SetPrepare()
c.Assert(yyParse(l), Equals, 0)
c.Assert(len(l.Stmts()), Equals, 2)
@ -87,11 +65,12 @@ func (s *testParserSuite) TestSimple(c *C) {
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
st := l.Stmts()[0]
ss, ok := st.(*stmts.SelectStmt)
ss, ok := st.(*ast.SelectStmt)
c.Assert(ok, IsTrue)
cv, ok := ss.Fields[0].Expr.(*expression.FunctionCast)
c.Assert(len(ss.Fields.Fields), Equals, 1)
cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr)
c.Assert(ok, IsTrue)
c.Assert(cv.FunctionType, Equals, expression.ConvertFunction)
c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction)
// For query start with comment
srcs := []string{
@ -105,7 +84,7 @@ func (s *testParserSuite) TestSimple(c *C) {
l = NewLexer(src)
c.Assert(yyParse(l), Equals, 0)
st = l.Stmts()[0]
ss, ok = st.(*stmts.SelectStmt)
ss, ok = st.(*ast.SelectStmt)
c.Assert(ok, IsTrue)
}
}
@ -172,9 +151,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
{"REPLACE INTO foo () VALUES ()", true},
{"REPLACE INTO foo VALUE ()", true},
// 40
{`SELECT stuff.id
FROM stuff
WHERE stuff.value >= ALL (SELECT stuff.value
{`SELECT stuff.id
FROM stuff
WHERE stuff.value >= ALL (SELECT stuff.value
FROM stuff)`, true},
{"BEGIN", true},
{"START TRANSACTION", true},

View File

@ -22,25 +22,25 @@ import (
"fmt"
"math"
"strconv"
"unicode"
"strings"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/util/stringutil"
"github.com/pingcap/tidb/util/charset"
"unicode"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/util/charset"
"github.com/pingcap/tidb/util/stringutil"
)
type lexer struct {
c int
col int
errs []error
expr expression.Expression
expr ast.ExprNode
i int
inj int
lcol int
line int
list []interface{}
list []ast.StmtNode
ncol int
nline int
sc int
@ -48,8 +48,7 @@ type lexer struct {
val []byte
ungetBuf []byte
root bool
prepare bool
ParamList []*expression.ParamMarker
prepare bool
stmtStartPos int
stringLit []byte
@ -78,11 +77,11 @@ func (l *lexer) Errors() []error {
return l.errs
}
func (l *lexer) Stmts() []interface{}{
func (l *lexer) Stmts() []ast.StmtNode {
return l.list
}
func (l *lexer) Expr() expression.Expression {
func (l *lexer) Expr() ast.ExprNode {
return l.expr
}
@ -90,16 +89,16 @@ func (l *lexer) Inj() int {
return l.inj
}
func (l *lexer) SetInj(inj int) {
l.inj = inj
}
func (l *lexer) SetPrepare() {
l.prepare = true
l.prepare = true
}
func (l *lexer) IsPrepare() bool {
return l.prepare
}
func (l *lexer) SetInj(inj int) {
l.inj = inj
return l.prepare
}
func (l *lexer) Root() bool {
@ -116,7 +115,17 @@ func (l *lexer) SetCharsetInfo(charset, collation string) {
}
func (l *lexer) GetCharsetInfo() (string, string) {
return l.charset, l.collation
return l.charset, l.collation
}
// The select statement is not at the end of the whole statement, if the last
// field text was set from its offset to the end of the src string, update
// the last field text.
func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) {
lastField := st.Fields.Fields[len(st.Fields.Fields)-1]
if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 {
lastField.SetText(l.src[lastField.Offset:lastEnd])
}
}
func (l *lexer) unget(b byte) {
@ -822,9 +831,9 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{repeat} lval.item = string(l.val)
return repeat
{regexp} return regexp
{references} return references
{replace} lval.item = string(l.val)
return replace
{references} return references
{rlike} return rlike
{sys_var} lval.item = string(l.val)

View File

@ -1,92 +0,0 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package parser
import (
"fmt"
"unicode"
. "github.com/pingcap/check"
)
func tok2name(i int) string {
if i == unicode.ReplacementChar {
return "<?>"
}
if i < 128 {
return fmt.Sprintf("tok-'%c'", i)
}
return fmt.Sprintf("tok-%d", i)
}
func (s *testParserSuite) TestScaner0(c *C) {
table := []struct {
src string
tok, line, col, nline, ncol int
val string
}{
{"a", identifier, 1, 1, 1, 2, "a"},
{" a", identifier, 1, 2, 1, 3, "a"},
{"a ", identifier, 1, 1, 1, 2, "a"},
{" a ", identifier, 1, 2, 1, 3, "a"},
{"\na", identifier, 2, 1, 2, 2, "a"},
{"a\n", identifier, 1, 1, 1, 2, "a"},
{"\na\n", identifier, 2, 1, 2, 2, "a"},
{"\n a", identifier, 2, 2, 2, 3, "a"},
{"a \n", identifier, 1, 1, 1, 2, "a"},
{"\n a \n", identifier, 2, 2, 2, 3, "a"},
{"ab", identifier, 1, 1, 1, 3, "ab"},
{" ab", identifier, 1, 2, 1, 4, "ab"},
{"ab ", identifier, 1, 1, 1, 3, "ab"},
{" ab ", identifier, 1, 2, 1, 4, "ab"},
{"\nab", identifier, 2, 1, 2, 3, "ab"},
{"ab\n", identifier, 1, 1, 1, 3, "ab"},
{"\nab\n", identifier, 2, 1, 2, 3, "ab"},
{"\n ab", identifier, 2, 2, 2, 4, "ab"},
{"ab \n", identifier, 1, 1, 1, 3, "ab"},
{"\n ab \n", identifier, 2, 2, 2, 4, "ab"},
{"c", identifier, 1, 1, 1, 2, "c"},
{"cR", identifier, 1, 1, 1, 3, "cR"},
{"cRe", identifier, 1, 1, 1, 4, "cRe"},
{"cReA", identifier, 1, 1, 1, 5, "cReA"},
{"cReAt", identifier, 1, 1, 1, 6, "cReAt"},
{"cReATe", create, 1, 1, 1, 7, "cReATe"},
{"cReATeD", identifier, 1, 1, 1, 8, "cReATeD"},
{"2", intLit, 1, 1, 1, 2, "2"},
{"2.", floatLit, 1, 1, 1, 3, "2."},
{"2.3", floatLit, 1, 1, 1, 4, "2.3"},
}
lval := &yySymType{}
for _, t := range table {
l := NewLexer(t.src)
tok := l.Lex(lval)
nline, ncol := l.npos()
val := string(l.val)
c.Assert(tok, Equals, t.tok)
c.Assert(l.line, Equals, t.line)
c.Assert(l.col, Equals, t.col)
c.Assert(nline, Equals, t.nline)
c.Assert(ncol, Equals, t.ncol)
c.Assert(val, Equals, t.val)
}
}

View File

@ -62,7 +62,7 @@ func (s *testStmtSuite) TestDropTable(c *C) {
}
func (s *testStmtSuite) TestDropIndex(c *C) {
testSQL := "drop index if exists drop_index;"
testSQL := "drop index if exists drop_index on t;"
stmtList, err := tidb.Compile(s.ctx, testSQL)
c.Assert(err, IsNil)

View File

@ -51,5 +51,5 @@ func (s *testStmtSuite) TestShow(c *C) {
c.Assert(stmtList, HasLen, 1)
testStmt, ok = stmtList[0].(*stmts.ShowStmt)
c.Assert(ok, IsTrue)
c.Assert(testStmt.Pattern, NotNil)
c.Assert(testStmt.Pattern, NotNil, Commentf("S: %s", testStmt.Text))
}

21
tidb.go
View File

@ -26,7 +26,6 @@ import (
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/domain"
"github.com/pingcap/tidb/expression"
@ -132,15 +131,12 @@ func Compile(ctx context.Context, src string) ([]stmt.Statement, error) {
rawStmt := l.Stmts()
stmts := make([]stmt.Statement, len(rawStmt))
for i, v := range rawStmt {
if node, ok := v.(ast.Node); ok {
stm, err := optimizer.Compile(node)
if err != nil {
return nil, errors.Trace(err)
}
stmts[i] = stm
} else {
stmts[i] = v.(stmt.Statement)
compiler := &optimizer.Compiler{}
stm, err := compiler.Compile(v)
if err != nil {
return nil, errors.Trace(err)
}
stmts[i] = stm
}
return stmts, nil
}
@ -162,7 +158,12 @@ func CompilePrepare(ctx context.Context, src string) (stmt.Statement, []*express
return nil, nil, nil
}
sm := sms[0]
return sm.(stmt.Statement), l.ParamList, nil
compiler := &optimizer.Compiler{}
statement, err := compiler.Compile(sm)
if err != nil {
return nil, nil, errors.Trace(err)
}
return statement, compiler.ParamMarkers(), nil
}
func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) {

View File

@ -640,6 +640,7 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) {
// conflict
mustExecSQL(c, se1, "begin")
rs, err := exec(c, se1, "select * from t where c1=11 for update")
c.Assert(err, IsNil)
_, err = rs.Rows(-1, 0)
mustExecSQL(c, se2, "begin")