Merge branch 'master' into coocood/use-terror
This commit is contained in:
22
Makefile
22
Makefile
@ -13,7 +13,7 @@ TARGET = ""
|
||||
|
||||
.PHONY: godep deps all build install parser clean todo test gotest interpreter server
|
||||
|
||||
all: godep parser ast-parser build test check
|
||||
all: godep parser build test check
|
||||
|
||||
godep:
|
||||
go get github.com/tools/godep
|
||||
@ -31,7 +31,7 @@ parser:
|
||||
goyacc -o /dev/null -xegen $$a parser/parser.y; \
|
||||
goyacc -o parser/parser.go -xe $$a parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \
|
||||
rm -f $$a; \
|
||||
rm -f y.output
|
||||
rm -f y.output
|
||||
|
||||
@if [ $(ARCH) = $(LINUX) ]; \
|
||||
then \
|
||||
@ -44,24 +44,6 @@ parser:
|
||||
|
||||
golex -o parser/scanner.go parser/scanner.l
|
||||
|
||||
ast-parser:
|
||||
a=`mktemp temp.XXXXXX`; \
|
||||
goyacc -o /dev/null -xegen $$a ast/parser/parser.y; \
|
||||
goyacc -o ast/parser/parser.go -xe $$a ast/parser/parser.y 2>&1 | grep "shift/reduce" | awk '{print} END {if (NR > 0) {print "Find conflict in parser.y. Please check y.output for more information."; exit 1;}}'; \
|
||||
rm -f $$a; \
|
||||
rm -f y.output
|
||||
|
||||
@if [ $(ARCH) = $(LINUX) ]; \
|
||||
then \
|
||||
sed -i -e 's|//line.*||' -e 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \
|
||||
elif [ $(ARCH) = $(MAC) ]; \
|
||||
then \
|
||||
/usr/bin/sed -i "" 's|//line.*||' ast/parser/parser.go; \
|
||||
/usr/bin/sed -i "" 's/yyEofCode/yyEOFCode/' ast/parser/parser.go; \
|
||||
fi
|
||||
|
||||
golex -o ast/parser/scanner.go ast/parser/scanner.l
|
||||
|
||||
check:
|
||||
go get github.com/golang/lint/golint
|
||||
|
||||
|
||||
14
ast/ast.go
14
ast/ast.go
@ -26,6 +26,11 @@ type Node interface {
|
||||
// Accept accepts Visitor to visit itself.
|
||||
// The returned node should replace original node.
|
||||
// ok returns false to stop visiting.
|
||||
//
|
||||
// Implementation of this method should first call visitor.Enter,
|
||||
// assign the returned node to its method receiver, if skipChildren returns true,
|
||||
// children should be skipped. Otherwise, call its children in particular order that
|
||||
// later elements depends on former elements. Finally, return visitor.Leave.
|
||||
Accept(v Visitor) (node Node, ok bool)
|
||||
// Text returns the original text of the element.
|
||||
Text() string
|
||||
@ -44,6 +49,10 @@ type ExprNode interface {
|
||||
SetType(tp *types.FieldType)
|
||||
// GetType gets the evaluation type of the expression.
|
||||
GetType() *types.FieldType
|
||||
// SetValue sets value to the expression.
|
||||
SetValue(val interface{})
|
||||
// GetValue gets value of the expression.
|
||||
GetValue() interface{}
|
||||
}
|
||||
|
||||
// FuncNode represents function call expression node.
|
||||
@ -93,11 +102,12 @@ type ResultSetNode interface {
|
||||
// Visitor visits a Node.
|
||||
type Visitor interface {
|
||||
// VisitEnter is called before children nodes is visited.
|
||||
// The returned node must be the same type as the input node n.
|
||||
// skipChildren returns true means children nodes should be skipped,
|
||||
// this is useful when work is done in Enter and there is no need to visit children.
|
||||
// ok returns false to stop visiting.
|
||||
Enter(n Node) (skipChildren bool, ok bool)
|
||||
Enter(n Node) (node Node, skipChildren bool)
|
||||
// VisitLeave is called after children nodes has been visited.
|
||||
// The returned node must be the same type as the input node n.
|
||||
// ok returns false to stop visiting.
|
||||
Leave(n Node) (node Node, ok bool)
|
||||
}
|
||||
|
||||
16
ast/base.go
16
ast/base.go
@ -62,7 +62,7 @@ func (dn *dmlNode) dmlStatement() {}
|
||||
// Expression implementations should embed it in.
|
||||
type exprNode struct {
|
||||
node
|
||||
tp *types.FieldType
|
||||
types.DataItem
|
||||
}
|
||||
|
||||
// IsStatic implements Expression interface.
|
||||
@ -72,12 +72,22 @@ func (en *exprNode) IsStatic() bool {
|
||||
|
||||
// SetType implements Expression interface.
|
||||
func (en *exprNode) SetType(tp *types.FieldType) {
|
||||
en.tp = tp
|
||||
en.Type = tp
|
||||
}
|
||||
|
||||
// GetType implements Expression interface.
|
||||
func (en *exprNode) GetType() *types.FieldType {
|
||||
return en.tp
|
||||
return en.Type
|
||||
}
|
||||
|
||||
// SetValue implements Expression interface.
|
||||
func (en *exprNode) SetValue(val interface{}) {
|
||||
en.Data = val
|
||||
}
|
||||
|
||||
// GetValue implements Expression interface.
|
||||
func (en *exprNode) GetValue() interface{} {
|
||||
return en.Data
|
||||
}
|
||||
|
||||
type funcNode struct {
|
||||
|
||||
173
ast/cloner.go
Normal file
173
ast/cloner.go
Normal file
@ -0,0 +1,173 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package ast
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Cloner is an ast visitor that clones a node.
|
||||
type Cloner struct {
|
||||
}
|
||||
|
||||
// Enter implements Visitor Enter interface.
|
||||
func (c *Cloner) Enter(node Node) (Node, bool) {
|
||||
return copyStruct(node), false
|
||||
}
|
||||
|
||||
// Leave implements Visitor Leave interface.
|
||||
func (c *Cloner) Leave(in Node) (out Node, ok bool) {
|
||||
return in, true
|
||||
}
|
||||
|
||||
// copyStruct copies a node's struct value, if the struct has slice member,
|
||||
// make a new slice and copy old slice value to new slice.
|
||||
func copyStruct(in Node) (out Node) {
|
||||
switch v := in.(type) {
|
||||
case *ValueExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *BetweenExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *BinaryOperationExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *WhenClause:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *CaseExpr:
|
||||
nv := *v
|
||||
nv.WhenClauses = make([]*WhenClause, len(v.WhenClauses))
|
||||
copy(nv.WhenClauses, v.WhenClauses)
|
||||
out = &nv
|
||||
case *SubqueryExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *CompareSubqueryExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *ColumnName:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *ColumnNameExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *DefaultExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *IdentifierExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *ExistsSubqueryExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *PatternInExpr:
|
||||
nv := *v
|
||||
nv.List = make([]ExprNode, len(v.List))
|
||||
copy(nv.List, v.List)
|
||||
out = &nv
|
||||
case *IsNullExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *IsTruthExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *PatternLikeExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *ParamMarkerExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *ParenthesesExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *PositionExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *PatternRegexpExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *RowExpr:
|
||||
nv := *v
|
||||
nv.Values = make([]ExprNode, len(v.Values))
|
||||
copy(nv.Values, v.Values)
|
||||
out = &nv
|
||||
case *UnaryOperationExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *ValuesExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *VariableExpr:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *Join:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *TableName:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *TableSource:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *OnCondition:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *WildCardField:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *SelectField:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *FieldList:
|
||||
nv := *v
|
||||
nv.Fields = make([]*SelectField, len(v.Fields))
|
||||
copy(nv.Fields, v.Fields)
|
||||
out = &nv
|
||||
case *TableRefsClause:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *ByItem:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *GroupByClause:
|
||||
nv := *v
|
||||
nv.Items = make([]*ByItem, len(v.Items))
|
||||
copy(nv.Items, v.Items)
|
||||
out = &nv
|
||||
case *HavingClause:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *OrderByClause:
|
||||
nv := *v
|
||||
nv.Items = make([]*ByItem, len(v.Items))
|
||||
copy(nv.Items, v.Items)
|
||||
out = &nv
|
||||
case *SelectStmt:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *UnionClause:
|
||||
nv := *v
|
||||
out = &nv
|
||||
case *UnionStmt:
|
||||
nv := *v
|
||||
nv.Selects = make([]*SelectStmt, len(v.Selects))
|
||||
copy(nv.Selects, v.Selects)
|
||||
out = &nv
|
||||
default:
|
||||
// We currently only handle expression and select statement.
|
||||
// Will add more when we need to.
|
||||
panic("unknown ast Node type " + fmt.Sprintf("%T", v))
|
||||
}
|
||||
return
|
||||
}
|
||||
40
ast/cloner_test.go
Normal file
40
ast/cloner_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
)
|
||||
|
||||
func TestT(t *testing.T) {
|
||||
TestingT(t)
|
||||
}
|
||||
|
||||
var _ = Suite(&testClonerSuite{})
|
||||
|
||||
type testClonerSuite struct {
|
||||
}
|
||||
|
||||
func (ts *testClonerSuite) TestCloner(c *C) {
|
||||
cloner := &Cloner{}
|
||||
|
||||
a := &UnaryOperationExpr{
|
||||
Op: opcode.Not,
|
||||
V: &UnaryOperationExpr{V: NewValueExpr(true)},
|
||||
}
|
||||
|
||||
b, ok := a.Accept(cloner)
|
||||
c.Assert(ok, IsTrue)
|
||||
a1 := a.V
|
||||
b1 := b.(*UnaryOperationExpr).V
|
||||
c.Assert(a1, Not(Equals), b1)
|
||||
a2 := a1.(*UnaryOperationExpr).V
|
||||
b2 := b1.(*UnaryOperationExpr).V
|
||||
c.Assert(a2, Not(Equals), b2)
|
||||
a3 := a2.(*ValueExpr)
|
||||
b3 := b2.(*ValueExpr)
|
||||
c.Assert(a3, Not(Equals), b3)
|
||||
c.Assert(a3.GetValue(), Equals, true)
|
||||
c.Assert(b3.GetValue(), Equals, true)
|
||||
}
|
||||
304
ast/ddl.go
304
ast/ddl.go
@ -70,11 +70,13 @@ type CreateDatabaseStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (cd *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(cd); skipChildren {
|
||||
return cd, ok
|
||||
func (n *CreateDatabaseStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(cd)
|
||||
n = newNod.(*CreateDatabaseStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// DropDatabaseStmt is a statement to drop a database and all tables in the database.
|
||||
@ -87,11 +89,13 @@ type DropDatabaseStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (dd *DropDatabaseStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(dd); skipChildren {
|
||||
return dd, ok
|
||||
func (n *DropDatabaseStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(dd)
|
||||
n = newNod.(*DropDatabaseStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IndexColName is used for parsing index column name from SQL.
|
||||
@ -103,16 +107,18 @@ type IndexColName struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ic *IndexColName) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ic); skipChildren {
|
||||
return ic, ok
|
||||
func (n *IndexColName) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ic.Column.Accept(v)
|
||||
n = newNod.(*IndexColName)
|
||||
node, ok := n.Column.Accept(v)
|
||||
if !ok {
|
||||
return ic, false
|
||||
return n, false
|
||||
}
|
||||
ic.Column = node.(*ColumnName)
|
||||
return v.Leave(ic)
|
||||
n.Column = node.(*ColumnName)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ReferenceDef is used for parsing foreign key reference option from SQL.
|
||||
@ -125,23 +131,25 @@ type ReferenceDef struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (rd *ReferenceDef) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(rd); skipChildren {
|
||||
return rd, ok
|
||||
func (n *ReferenceDef) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := rd.Table.Accept(v)
|
||||
n = newNod.(*ReferenceDef)
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return rd, false
|
||||
return n, false
|
||||
}
|
||||
rd.Table = node.(*TableName)
|
||||
for i, val := range rd.IndexColNames {
|
||||
n.Table = node.(*TableName)
|
||||
for i, val := range n.IndexColNames {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return rd, false
|
||||
return n, false
|
||||
}
|
||||
rd.IndexColNames[i] = node.(*IndexColName)
|
||||
n.IndexColNames[i] = node.(*IndexColName)
|
||||
}
|
||||
return v.Leave(rd)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ColumnOptionType is the type for ColumnOption.
|
||||
@ -175,18 +183,20 @@ type ColumnOption struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (co *ColumnOption) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(co); skipChildren {
|
||||
return co, ok
|
||||
func (n *ColumnOption) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if co.Expr != nil {
|
||||
node, ok := co.Expr.Accept(v)
|
||||
n = newNod.(*ColumnOption)
|
||||
if n.Expr != nil {
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return co, false
|
||||
return n, false
|
||||
}
|
||||
co.Expr = node.(ExprNode)
|
||||
n.Expr = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(co)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ConstraintType is the type for Constraint.
|
||||
@ -220,25 +230,27 @@ type Constraint struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (tc *Constraint) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(tc); skipChildren {
|
||||
return tc, ok
|
||||
func (n *Constraint) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range tc.Keys {
|
||||
n = newNod.(*Constraint)
|
||||
for i, val := range n.Keys {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return tc, false
|
||||
return n, false
|
||||
}
|
||||
tc.Keys[i] = node.(*IndexColName)
|
||||
n.Keys[i] = node.(*IndexColName)
|
||||
}
|
||||
if tc.Refer != nil {
|
||||
node, ok := tc.Refer.Accept(v)
|
||||
if n.Refer != nil {
|
||||
node, ok := n.Refer.Accept(v)
|
||||
if !ok {
|
||||
return tc, false
|
||||
return n, false
|
||||
}
|
||||
tc.Refer = node.(*ReferenceDef)
|
||||
n.Refer = node.(*ReferenceDef)
|
||||
}
|
||||
return v.Leave(tc)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ColumnDef is used for parsing column definition from SQL.
|
||||
@ -251,23 +263,25 @@ type ColumnDef struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (cd *ColumnDef) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(cd); skipChildren {
|
||||
return cd, ok
|
||||
func (n *ColumnDef) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := cd.Name.Accept(v)
|
||||
n = newNod.(*ColumnDef)
|
||||
node, ok := n.Name.Accept(v)
|
||||
if !ok {
|
||||
return cd, false
|
||||
return n, false
|
||||
}
|
||||
cd.Name = node.(*ColumnName)
|
||||
for i, val := range cd.Options {
|
||||
n.Name = node.(*ColumnName)
|
||||
for i, val := range n.Options {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return cd, false
|
||||
return n, false
|
||||
}
|
||||
cd.Options[i] = node.(*ColumnOption)
|
||||
n.Options[i] = node.(*ColumnOption)
|
||||
}
|
||||
return v.Leave(cd)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// CreateTableStmt is a statement to create a table.
|
||||
@ -283,30 +297,32 @@ type CreateTableStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ct *CreateTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ct); skipChildren {
|
||||
return ct, ok
|
||||
func (n *CreateTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ct.Table.Accept(v)
|
||||
n = newNod.(*CreateTableStmt)
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return ct, false
|
||||
return n, false
|
||||
}
|
||||
ct.Table = node.(*TableName)
|
||||
for i, val := range ct.Cols {
|
||||
n.Table = node.(*TableName)
|
||||
for i, val := range n.Cols {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return ct, false
|
||||
return n, false
|
||||
}
|
||||
ct.Cols[i] = node.(*ColumnDef)
|
||||
n.Cols[i] = node.(*ColumnDef)
|
||||
}
|
||||
for i, val := range ct.Constraints {
|
||||
for i, val := range n.Constraints {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return ct, false
|
||||
return n, false
|
||||
}
|
||||
ct.Constraints[i] = node.(*Constraint)
|
||||
n.Constraints[i] = node.(*Constraint)
|
||||
}
|
||||
return v.Leave(ct)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// DropTableStmt is a statement to drop one or more tables.
|
||||
@ -319,18 +335,20 @@ type DropTableStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (dt *DropTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(dt); skipChildren {
|
||||
return dt, ok
|
||||
func (n *DropTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range dt.Tables {
|
||||
n = newNod.(*DropTableStmt)
|
||||
for i, val := range n.Tables {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return dt, false
|
||||
return n, false
|
||||
}
|
||||
dt.Tables[i] = node.(*TableName)
|
||||
n.Tables[i] = node.(*TableName)
|
||||
}
|
||||
return v.Leave(dt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// CreateIndexStmt is a statement to create an index.
|
||||
@ -345,23 +363,25 @@ type CreateIndexStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ci *CreateIndexStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ci); skipChildren {
|
||||
return ci, ok
|
||||
func (n *CreateIndexStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ci.Table.Accept(v)
|
||||
n = newNod.(*CreateIndexStmt)
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return ci, false
|
||||
return n, false
|
||||
}
|
||||
ci.Table = node.(*TableName)
|
||||
for i, val := range ci.IndexColNames {
|
||||
n.Table = node.(*TableName)
|
||||
for i, val := range n.IndexColNames {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return ci, false
|
||||
return n, false
|
||||
}
|
||||
ci.IndexColNames[i] = node.(*IndexColName)
|
||||
n.IndexColNames[i] = node.(*IndexColName)
|
||||
}
|
||||
return v.Leave(ci)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// DropIndexStmt is a statement to drop the index.
|
||||
@ -375,16 +395,18 @@ type DropIndexStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (di *DropIndexStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(di); skipChildren {
|
||||
return di, ok
|
||||
func (n *DropIndexStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := di.Table.Accept(v)
|
||||
n = newNod.(*DropIndexStmt)
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return di, false
|
||||
return n, false
|
||||
}
|
||||
di.Table = node.(*TableName)
|
||||
return v.Leave(di)
|
||||
n.Table = node.(*TableName)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// TableOptionType is the type for TableOption
|
||||
@ -435,18 +457,20 @@ type ColumnPosition struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (cp *ColumnPosition) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(cp); skipChildren {
|
||||
return cp, ok
|
||||
func (n *ColumnPosition) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if cp.RelativeColumn != nil {
|
||||
node, ok := cp.RelativeColumn.Accept(v)
|
||||
n = newNod.(*ColumnPosition)
|
||||
if n.RelativeColumn != nil {
|
||||
node, ok := n.RelativeColumn.Accept(v)
|
||||
if !ok {
|
||||
return cp, false
|
||||
return n, false
|
||||
}
|
||||
cp.RelativeColumn = node.(*ColumnName)
|
||||
n.RelativeColumn = node.(*ColumnName)
|
||||
}
|
||||
return v.Leave(cp)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// AlterTableType is the type for AlterTableSpec.
|
||||
@ -474,44 +498,46 @@ type AlterTableSpec struct {
|
||||
Constraint *Constraint
|
||||
Options []*TableOption
|
||||
Column *ColumnDef
|
||||
ColumnName *ColumnName
|
||||
DropColumn *ColumnName
|
||||
Position *ColumnPosition
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (as *AlterTableSpec) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(as); skipChildren {
|
||||
return as, ok
|
||||
func (n *AlterTableSpec) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if as.Constraint != nil {
|
||||
node, ok := as.Constraint.Accept(v)
|
||||
n = newNod.(*AlterTableSpec)
|
||||
if n.Constraint != nil {
|
||||
node, ok := n.Constraint.Accept(v)
|
||||
if !ok {
|
||||
return as, false
|
||||
return n, false
|
||||
}
|
||||
as.Constraint = node.(*Constraint)
|
||||
n.Constraint = node.(*Constraint)
|
||||
}
|
||||
if as.Column != nil {
|
||||
node, ok := as.Column.Accept(v)
|
||||
if n.Column != nil {
|
||||
node, ok := n.Column.Accept(v)
|
||||
if !ok {
|
||||
return as, false
|
||||
return n, false
|
||||
}
|
||||
as.Column = node.(*ColumnDef)
|
||||
n.Column = node.(*ColumnDef)
|
||||
}
|
||||
if as.ColumnName != nil {
|
||||
node, ok := as.ColumnName.Accept(v)
|
||||
if n.DropColumn != nil {
|
||||
node, ok := n.DropColumn.Accept(v)
|
||||
if !ok {
|
||||
return as, false
|
||||
return n, false
|
||||
}
|
||||
as.ColumnName = node.(*ColumnName)
|
||||
n.DropColumn = node.(*ColumnName)
|
||||
}
|
||||
if as.Position != nil {
|
||||
node, ok := as.Position.Accept(v)
|
||||
if n.Position != nil {
|
||||
node, ok := n.Position.Accept(v)
|
||||
if !ok {
|
||||
return as, false
|
||||
return n, false
|
||||
}
|
||||
as.Position = node.(*ColumnPosition)
|
||||
n.Position = node.(*ColumnPosition)
|
||||
}
|
||||
return v.Leave(as)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// AlterTableStmt is a statement to change the structure of a table.
|
||||
@ -524,23 +550,25 @@ type AlterTableStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (at *AlterTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(at); skipChildren {
|
||||
return at, ok
|
||||
func (n *AlterTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := at.Table.Accept(v)
|
||||
n = newNod.(*AlterTableStmt)
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return at, false
|
||||
return n, false
|
||||
}
|
||||
at.Table = node.(*TableName)
|
||||
for i, val := range at.Specs {
|
||||
n.Table = node.(*TableName)
|
||||
for i, val := range n.Specs {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return at, false
|
||||
return n, false
|
||||
}
|
||||
at.Specs[i] = node.(*AlterTableSpec)
|
||||
n.Specs[i] = node.(*AlterTableSpec)
|
||||
}
|
||||
return v.Leave(at)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// TruncateTableStmt is a statement to empty a table completely.
|
||||
@ -552,14 +580,16 @@ type TruncateTableStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ts *TruncateTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ts); skipChildren {
|
||||
return ts, ok
|
||||
func (n *TruncateTableStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ts.Table.Accept(v)
|
||||
n = newNod.(*TruncateTableStmt)
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return ts, false
|
||||
return n, false
|
||||
}
|
||||
ts.Table = node.(*TableName)
|
||||
return v.Leave(ts)
|
||||
n.Table = node.(*TableName)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
650
ast/dml.go
650
ast/dml.go
@ -22,6 +22,7 @@ var (
|
||||
_ DMLNode = &DeleteStmt{}
|
||||
_ DMLNode = &UpdateStmt{}
|
||||
_ DMLNode = &SelectStmt{}
|
||||
_ DMLNode = &UnionStmt{}
|
||||
_ Node = &Join{}
|
||||
_ Node = &TableName{}
|
||||
_ Node = &TableSource{}
|
||||
@ -55,34 +56,36 @@ type Join struct {
|
||||
// Tp represents join type.
|
||||
Tp JoinType
|
||||
// On represents join on condition.
|
||||
On ExprNode
|
||||
On *OnCondition
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (j *Join) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(j); skipChildren {
|
||||
return j, ok
|
||||
func (n *Join) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := j.Left.Accept(v)
|
||||
n = newNod.(*Join)
|
||||
node, ok := n.Left.Accept(v)
|
||||
if !ok {
|
||||
return j, false
|
||||
return n, false
|
||||
}
|
||||
j.Left = node.(ResultSetNode)
|
||||
if j.Right != nil {
|
||||
node, ok = j.Right.Accept(v)
|
||||
n.Left = node.(ResultSetNode)
|
||||
if n.Right != nil {
|
||||
node, ok = n.Right.Accept(v)
|
||||
if !ok {
|
||||
return j, false
|
||||
return n, false
|
||||
}
|
||||
j.Right = node.(ResultSetNode)
|
||||
n.Right = node.(ResultSetNode)
|
||||
}
|
||||
if j.On != nil {
|
||||
node, ok = j.On.Accept(v)
|
||||
if n.On != nil {
|
||||
node, ok = n.On.Accept(v)
|
||||
if !ok {
|
||||
return j, false
|
||||
return n, false
|
||||
}
|
||||
j.On = node.(ExprNode)
|
||||
n.On = node.(*OnCondition)
|
||||
}
|
||||
return v.Leave(j)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// TableName represents a table name.
|
||||
@ -98,11 +101,13 @@ type TableName struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (tr *TableName) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(tr); skipChildren {
|
||||
return tr, ok
|
||||
func (n *TableName) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(tr)
|
||||
n = newNod.(*TableName)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// TableSource represents table source with a name.
|
||||
@ -110,7 +115,7 @@ type TableSource struct {
|
||||
node
|
||||
|
||||
// Source is the source of the data, can be a TableName,
|
||||
// a SubQuery, or a JoinNode.
|
||||
// a SelectStmt, a UnionStmt, or a JoinNode.
|
||||
Source ResultSetNode
|
||||
|
||||
// AsName is the as name of the table source.
|
||||
@ -118,47 +123,50 @@ type TableSource struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ts *TableSource) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ts); skipChildren {
|
||||
return ts, ok
|
||||
func (n *TableSource) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ts.Source.Accept(v)
|
||||
n = newNod.(*TableSource)
|
||||
node, ok := n.Source.Accept(v)
|
||||
if !ok {
|
||||
return ts, false
|
||||
return n, false
|
||||
}
|
||||
ts.Source = node.(ResultSetNode)
|
||||
return v.Leave(ts)
|
||||
n.Source = node.(ResultSetNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SetResultFields implements ResultSet interface.
|
||||
func (ts *TableSource) SetResultFields(rfs []*ResultField) {
|
||||
ts.Source.SetResultFields(rfs)
|
||||
}
|
||||
|
||||
// GetResultFields implements ResultSet interface.
|
||||
func (ts *TableSource) GetResultFields() []*ResultField {
|
||||
return ts.Source.GetResultFields()
|
||||
}
|
||||
|
||||
// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause.
|
||||
type UnionClause struct {
|
||||
// OnCondition represetns JOIN on condition.
|
||||
type OnCondition struct {
|
||||
node
|
||||
|
||||
Distinct bool
|
||||
Select *SelectStmt
|
||||
Expr ExprNode
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (uc *UnionClause) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(uc); skipChildren {
|
||||
return uc, ok
|
||||
func (n *OnCondition) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := uc.Select.Accept(v)
|
||||
n = newNod.(*OnCondition)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return uc, false
|
||||
return n, false
|
||||
}
|
||||
uc.Select = node.(*SelectStmt)
|
||||
return v.Leave(uc)
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SetResultFields implements ResultSet interface.
|
||||
func (n *TableSource) SetResultFields(rfs []*ResultField) {
|
||||
n.Source.SetResultFields(rfs)
|
||||
}
|
||||
|
||||
// GetResultFields implements ResultSet interface.
|
||||
func (n *TableSource) GetResultFields() []*ResultField {
|
||||
return n.Source.GetResultFields()
|
||||
}
|
||||
|
||||
// SelectLockType is the lock type for SelectStmt.
|
||||
@ -175,22 +183,18 @@ const (
|
||||
type WildCardField struct {
|
||||
node
|
||||
|
||||
Table *TableName
|
||||
Table model.CIStr
|
||||
Schema model.CIStr
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (wf *WildCardField) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(wf); skipChildren {
|
||||
return wf, ok
|
||||
func (n *WildCardField) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if wf.Table != nil {
|
||||
node, ok := wf.Table.Accept(v)
|
||||
if !ok {
|
||||
return wf, false
|
||||
}
|
||||
wf.Table = node.(*TableName)
|
||||
}
|
||||
return v.Leave(wf)
|
||||
n = newNod.(*WildCardField)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SelectField represents fields in select statement.
|
||||
@ -199,6 +203,8 @@ func (wf *WildCardField) Accept(v Visitor) (Node, bool) {
|
||||
type SelectField struct {
|
||||
node
|
||||
|
||||
// Offset is used to get original text.
|
||||
Offset int
|
||||
// If WildCard is not nil, Expr will be nil.
|
||||
WildCard *WildCardField
|
||||
// If Expr is not nil, WildCard will be nil.
|
||||
@ -208,22 +214,70 @@ type SelectField struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (sf *SelectField) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(sf); skipChildren {
|
||||
return sf, ok
|
||||
func (n *SelectField) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if sf.Expr != nil {
|
||||
node, ok := sf.Expr.Accept(v)
|
||||
n = newNod.(*SelectField)
|
||||
if n.Expr != nil {
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return sf, false
|
||||
return n, false
|
||||
}
|
||||
sf.Expr = node.(ExprNode)
|
||||
n.Expr = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(sf)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// OrderByItem represents a single order by item.
|
||||
type OrderByItem struct {
|
||||
// FieldList represents field list in select statement.
|
||||
type FieldList struct {
|
||||
node
|
||||
|
||||
Fields []*SelectField
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *FieldList) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*FieldList)
|
||||
for i, val := range n.Fields {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Fields[i] = node.(*SelectField)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// TableRefsClause represents table references clause in dml statement.
|
||||
type TableRefsClause struct {
|
||||
node
|
||||
|
||||
TableRefs *Join
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *TableRefsClause) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*TableRefsClause)
|
||||
node, ok := n.TableRefs.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.TableRefs = node.(*Join)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ByItem represents an item in order by or group by.
|
||||
type ByItem struct {
|
||||
node
|
||||
|
||||
Expr ExprNode
|
||||
@ -231,131 +285,244 @@ type OrderByItem struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ob *OrderByItem) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ob); skipChildren {
|
||||
return ob, ok
|
||||
func (n *ByItem) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ob.Expr.Accept(v)
|
||||
n = newNod.(*ByItem)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return ob, false
|
||||
return n, false
|
||||
}
|
||||
ob.Expr = node.(ExprNode)
|
||||
return v.Leave(ob)
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// GroupByClause represents group by clause.
|
||||
type GroupByClause struct {
|
||||
node
|
||||
Items []*ByItem
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *GroupByClause) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*GroupByClause)
|
||||
for i, val := range n.Items {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Items[i] = node.(*ByItem)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// HavingClause represents having clause.
|
||||
type HavingClause struct {
|
||||
node
|
||||
Expr ExprNode
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *HavingClause) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*HavingClause)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// OrderByClause represents order by clause.
|
||||
type OrderByClause struct {
|
||||
node
|
||||
Items []*ByItem
|
||||
ForUnion bool
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *OrderByClause) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*OrderByClause)
|
||||
for i, val := range n.Items {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Items[i] = node.(*ByItem)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SelectStmt represents the select query node.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/select.html
|
||||
type SelectStmt struct {
|
||||
dmlNode
|
||||
resultSetNode
|
||||
|
||||
// Distinct represents if the select has distinct option.
|
||||
Distinct bool
|
||||
// Fields is the select expression list.
|
||||
Fields []*SelectField
|
||||
// From is the from clause of the query.
|
||||
From *Join
|
||||
From *TableRefsClause
|
||||
// Where is the where clause in select statement.
|
||||
Where ExprNode
|
||||
// Fields is the select expression list.
|
||||
Fields *FieldList
|
||||
// GroupBy is the group by expression list.
|
||||
GroupBy []ExprNode
|
||||
GroupBy *GroupByClause
|
||||
// Having is the having condition.
|
||||
Having ExprNode
|
||||
// OrderBy is the odering expression list.
|
||||
OrderBy []*OrderByItem
|
||||
Having *HavingClause
|
||||
// OrderBy is the ordering expression list.
|
||||
OrderBy *OrderByClause
|
||||
// Limit is the limit clause.
|
||||
Limit *Limit
|
||||
// Lock is the lock type
|
||||
LockTp SelectLockType
|
||||
|
||||
// Union clauses.
|
||||
Unions []*UnionClause
|
||||
// Order by for union select.
|
||||
UnionOrderBy []*OrderByItem
|
||||
// Limit for union select.
|
||||
UnionLimit *Limit
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (sn *SelectStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(sn); skipChildren {
|
||||
return sn, ok
|
||||
func (n *SelectStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range sn.Fields {
|
||||
node, ok := val.Accept(v)
|
||||
n = newNod.(*SelectStmt)
|
||||
|
||||
if n.From != nil {
|
||||
node, ok := n.From.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.Fields[i] = node.(*SelectField)
|
||||
}
|
||||
if sn.From != nil {
|
||||
node, ok := sn.From.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
}
|
||||
sn.From = node.(*Join)
|
||||
n.From = node.(*TableRefsClause)
|
||||
}
|
||||
|
||||
if sn.Where != nil {
|
||||
node, ok := sn.Where.Accept(v)
|
||||
if n.Where != nil {
|
||||
node, ok := n.Where.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.Where = node.(ExprNode)
|
||||
n.Where = node.(ExprNode)
|
||||
}
|
||||
|
||||
for i, val := range sn.GroupBy {
|
||||
node, ok := val.Accept(v)
|
||||
if n.Fields != nil {
|
||||
node, ok := n.Fields.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.GroupBy[i] = node.(ExprNode)
|
||||
}
|
||||
if sn.Having != nil {
|
||||
node, ok := sn.Having.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
}
|
||||
sn.Having = node.(ExprNode)
|
||||
n.Fields = node.(*FieldList)
|
||||
}
|
||||
|
||||
for i, val := range sn.OrderBy {
|
||||
node, ok := val.Accept(v)
|
||||
if n.GroupBy != nil {
|
||||
node, ok := n.GroupBy.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.OrderBy[i] = node.(*OrderByItem)
|
||||
n.GroupBy = node.(*GroupByClause)
|
||||
}
|
||||
|
||||
if sn.Limit != nil {
|
||||
node, ok := sn.Limit.Accept(v)
|
||||
if n.Having != nil {
|
||||
node, ok := n.Having.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.Limit = node.(*Limit)
|
||||
n.Having = node.(*HavingClause)
|
||||
}
|
||||
|
||||
for i, val := range sn.Unions {
|
||||
if n.OrderBy != nil {
|
||||
node, ok := n.OrderBy.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.OrderBy = node.(*OrderByClause)
|
||||
}
|
||||
|
||||
if n.Limit != nil {
|
||||
node, ok := n.Limit.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Limit = node.(*Limit)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// UnionClause represents a single "UNION SELECT ..." or "UNION (SELECT ...)" clause.
|
||||
type UnionClause struct {
|
||||
node
|
||||
|
||||
Distinct bool
|
||||
Select *SelectStmt
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *UnionClause) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*UnionClause)
|
||||
node, ok := n.Select.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Select = node.(*SelectStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// UnionStmt represents "union statement"
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/union.html
|
||||
type UnionStmt struct {
|
||||
dmlNode
|
||||
resultSetNode
|
||||
|
||||
Distinct bool
|
||||
Selects []*SelectStmt
|
||||
OrderBy *OrderByClause
|
||||
Limit *Limit
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *UnionStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*UnionStmt)
|
||||
for i, val := range n.Selects {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.Unions[i] = node.(*UnionClause)
|
||||
n.Selects[i] = node.(*SelectStmt)
|
||||
}
|
||||
for i, val := range sn.UnionOrderBy {
|
||||
node, ok := val.Accept(v)
|
||||
if n.OrderBy != nil {
|
||||
node, ok := n.OrderBy.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.UnionOrderBy[i] = node.(*OrderByItem)
|
||||
n.OrderBy = node.(*OrderByClause)
|
||||
}
|
||||
if sn.UnionLimit != nil {
|
||||
node, ok := sn.UnionLimit.Accept(v)
|
||||
if n.Limit != nil {
|
||||
node, ok := n.Limit.Accept(v)
|
||||
if !ok {
|
||||
return sn, false
|
||||
return n, false
|
||||
}
|
||||
sn.UnionLimit = node.(*Limit)
|
||||
n.Limit = node.(*Limit)
|
||||
}
|
||||
return v.Leave(sn)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// Assignment is the expression for assignment, like a = 1.
|
||||
@ -368,21 +535,23 @@ type Assignment struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (as *Assignment) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(as); skipChildren {
|
||||
return as, ok
|
||||
func (n *Assignment) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := as.Column.Accept(v)
|
||||
n = newNod.(*Assignment)
|
||||
node, ok := n.Column.Accept(v)
|
||||
if !ok {
|
||||
return as, false
|
||||
return n, false
|
||||
}
|
||||
as.Column = node.(*ColumnName)
|
||||
node, ok = as.Expr.Accept(v)
|
||||
n.Column = node.(*ColumnName)
|
||||
node, ok = n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return as, false
|
||||
return n, false
|
||||
}
|
||||
as.Expr = node.(ExprNode)
|
||||
return v.Leave(as)
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// Priority const values.
|
||||
@ -399,58 +568,68 @@ const (
|
||||
type InsertStmt struct {
|
||||
dmlNode
|
||||
|
||||
Replace bool
|
||||
Table *TableRefsClause
|
||||
Columns []*ColumnName
|
||||
Lists [][]ExprNode
|
||||
Table *TableName
|
||||
Setlist []*Assignment
|
||||
Priority int
|
||||
OnDuplicate []*Assignment
|
||||
Select *SelectStmt
|
||||
Select ResultSetNode
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (in *InsertStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(in); skipChildren {
|
||||
return in, ok
|
||||
func (n *InsertStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range in.Columns {
|
||||
n = newNod.(*InsertStmt)
|
||||
|
||||
if n.Select != nil {
|
||||
node, ok := n.Select.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Select = node.(ResultSetNode)
|
||||
}
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Table = node.(*TableRefsClause)
|
||||
|
||||
for i, val := range n.Columns {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return in, false
|
||||
return n, false
|
||||
}
|
||||
in.Columns[i] = node.(*ColumnName)
|
||||
n.Columns[i] = node.(*ColumnName)
|
||||
}
|
||||
for i, list := range in.Lists {
|
||||
for i, list := range n.Lists {
|
||||
for j, val := range list {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return in, false
|
||||
return n, false
|
||||
}
|
||||
in.Lists[i][j] = node.(ExprNode)
|
||||
n.Lists[i][j] = node.(ExprNode)
|
||||
}
|
||||
}
|
||||
for i, val := range in.Setlist {
|
||||
for i, val := range n.Setlist {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return in, false
|
||||
return n, false
|
||||
}
|
||||
in.Setlist[i] = node.(*Assignment)
|
||||
n.Setlist[i] = node.(*Assignment)
|
||||
}
|
||||
for i, val := range in.OnDuplicate {
|
||||
for i, val := range n.OnDuplicate {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return in, false
|
||||
return n, false
|
||||
}
|
||||
in.OnDuplicate[i] = node.(*Assignment)
|
||||
n.OnDuplicate[i] = node.(*Assignment)
|
||||
}
|
||||
if in.Select != nil {
|
||||
node, ok := in.Select.Accept(v)
|
||||
if !ok {
|
||||
return in, false
|
||||
}
|
||||
in.Select = node.(*SelectStmt)
|
||||
}
|
||||
return v.Leave(in)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// DeleteStmt is a statement to delete rows from table.
|
||||
@ -458,10 +637,12 @@ func (in *InsertStmt) Accept(v Visitor) (Node, bool) {
|
||||
type DeleteStmt struct {
|
||||
dmlNode
|
||||
|
||||
TableRefs *Join
|
||||
// Used in both single table and multiple table delete statement.
|
||||
TableRefs *TableRefsClause
|
||||
// Only used in multiple table delete statement.
|
||||
Tables []*TableName
|
||||
Where ExprNode
|
||||
Order []*OrderByItem
|
||||
Order *OrderByClause
|
||||
Limit *Limit
|
||||
LowPriority bool
|
||||
Ignore bool
|
||||
@ -471,47 +652,49 @@ type DeleteStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (de *DeleteStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(de); skipChildren {
|
||||
return de, ok
|
||||
func (n *DeleteStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*DeleteStmt)
|
||||
|
||||
node, ok := de.TableRefs.Accept(v)
|
||||
node, ok := n.TableRefs.Accept(v)
|
||||
if !ok {
|
||||
return de, false
|
||||
return n, false
|
||||
}
|
||||
de.TableRefs = node.(*Join)
|
||||
n.TableRefs = node.(*TableRefsClause)
|
||||
|
||||
for i, val := range de.Tables {
|
||||
for i, val := range n.Tables {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return de, false
|
||||
return n, false
|
||||
}
|
||||
de.Tables[i] = node.(*TableName)
|
||||
n.Tables[i] = node.(*TableName)
|
||||
}
|
||||
|
||||
if de.Where != nil {
|
||||
node, ok = de.Where.Accept(v)
|
||||
if n.Where != nil {
|
||||
node, ok = n.Where.Accept(v)
|
||||
if !ok {
|
||||
return de, false
|
||||
return n, false
|
||||
}
|
||||
de.Where = node.(ExprNode)
|
||||
n.Where = node.(ExprNode)
|
||||
}
|
||||
|
||||
for i, val := range de.Order {
|
||||
node, ok = val.Accept(v)
|
||||
if n.Order != nil {
|
||||
node, ok = n.Order.Accept(v)
|
||||
if !ok {
|
||||
return de, false
|
||||
return n, false
|
||||
}
|
||||
de.Order[i] = node.(*OrderByItem)
|
||||
n.Order = node.(*OrderByClause)
|
||||
}
|
||||
|
||||
node, ok = de.Limit.Accept(v)
|
||||
if !ok {
|
||||
return de, false
|
||||
if n.Limit != nil {
|
||||
node, ok = n.Limit.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Limit = node.(*Limit)
|
||||
}
|
||||
de.Limit = node.(*Limit)
|
||||
return v.Leave(de)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// UpdateStmt is a statement to update columns of existing rows in tables with new values.
|
||||
@ -519,10 +702,10 @@ func (de *DeleteStmt) Accept(v Visitor) (Node, bool) {
|
||||
type UpdateStmt struct {
|
||||
dmlNode
|
||||
|
||||
TableRefs *Join
|
||||
TableRefs *TableRefsClause
|
||||
List []*Assignment
|
||||
Where ExprNode
|
||||
Order []*OrderByItem
|
||||
Order *OrderByClause
|
||||
Limit *Limit
|
||||
LowPriority bool
|
||||
Ignore bool
|
||||
@ -530,43 +713,46 @@ type UpdateStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (up *UpdateStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(up); skipChildren {
|
||||
return up, ok
|
||||
func (n *UpdateStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := up.TableRefs.Accept(v)
|
||||
n = newNod.(*UpdateStmt)
|
||||
node, ok := n.TableRefs.Accept(v)
|
||||
if !ok {
|
||||
return up, false
|
||||
return n, false
|
||||
}
|
||||
up.TableRefs = node.(*Join)
|
||||
for i, val := range up.List {
|
||||
n.TableRefs = node.(*TableRefsClause)
|
||||
for i, val := range n.List {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return up, false
|
||||
return n, false
|
||||
}
|
||||
up.List[i] = node.(*Assignment)
|
||||
n.List[i] = node.(*Assignment)
|
||||
}
|
||||
if up.Where != nil {
|
||||
node, ok = up.Where.Accept(v)
|
||||
if n.Where != nil {
|
||||
node, ok = n.Where.Accept(v)
|
||||
if !ok {
|
||||
return up, false
|
||||
return n, false
|
||||
}
|
||||
up.Where = node.(ExprNode)
|
||||
n.Where = node.(ExprNode)
|
||||
}
|
||||
|
||||
for i, val := range up.Order {
|
||||
node, ok = val.Accept(v)
|
||||
if n.Order != nil {
|
||||
node, ok = n.Order.Accept(v)
|
||||
if !ok {
|
||||
return up, false
|
||||
return n, false
|
||||
}
|
||||
up.Order[i] = node.(*OrderByItem)
|
||||
n.Order = node.(*OrderByClause)
|
||||
}
|
||||
node, ok = up.Limit.Accept(v)
|
||||
if !ok {
|
||||
return up, false
|
||||
if n.Limit != nil {
|
||||
node, ok = n.Limit.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Limit = node.(*Limit)
|
||||
}
|
||||
up.Limit = node.(*Limit)
|
||||
return v.Leave(up)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// Limit is the limit clause.
|
||||
@ -578,9 +764,11 @@ type Limit struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (l *Limit) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(l); skipChildren {
|
||||
return l, ok
|
||||
func (n *Limit) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(l)
|
||||
n = newNod.(*Limit)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
@ -14,8 +14,11 @@
|
||||
package ast
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -26,6 +29,7 @@ var (
|
||||
_ ExprNode = &CaseExpr{}
|
||||
_ ExprNode = &SubqueryExpr{}
|
||||
_ ExprNode = &CompareSubqueryExpr{}
|
||||
_ Node = &ColumnName{}
|
||||
_ ExprNode = &ColumnNameExpr{}
|
||||
_ ExprNode = &DefaultExpr{}
|
||||
_ ExprNode = &IdentifierExpr{}
|
||||
@ -47,21 +51,58 @@ var (
|
||||
// ValueExpr is the simple value expression.
|
||||
type ValueExpr struct {
|
||||
exprNode
|
||||
// Val is the literal value.
|
||||
Val interface{}
|
||||
}
|
||||
|
||||
// NewValueExpr creates a ValueExpr with value, and sets default field type.
|
||||
func NewValueExpr(value interface{}) *ValueExpr {
|
||||
ve := &ValueExpr{}
|
||||
ve.Data = types.RawData(value)
|
||||
// TODO: make it more precise.
|
||||
switch value.(type) {
|
||||
case nil:
|
||||
ve.Type = types.NewFieldType(mysql.TypeNull)
|
||||
case bool, int64:
|
||||
ve.Type = types.NewFieldType(mysql.TypeLonglong)
|
||||
case uint64:
|
||||
ve.Type = types.NewFieldType(mysql.TypeLonglong)
|
||||
ve.Type.Flag |= mysql.UnsignedFlag
|
||||
case string, UnquoteString:
|
||||
ve.Type = types.NewFieldType(mysql.TypeVarchar)
|
||||
ve.Type.Charset = mysql.DefaultCharset
|
||||
ve.Type.Collate = mysql.DefaultCollationName
|
||||
case float64:
|
||||
ve.Type = types.NewFieldType(mysql.TypeDouble)
|
||||
case []byte:
|
||||
ve.Type = types.NewFieldType(mysql.TypeBlob)
|
||||
ve.Type.Charset = "binary"
|
||||
ve.Type.Collate = "binary"
|
||||
case mysql.Bit:
|
||||
ve.Type = types.NewFieldType(mysql.TypeBit)
|
||||
case mysql.Hex:
|
||||
ve.Type = types.NewFieldType(mysql.TypeVarchar)
|
||||
ve.Type.Charset = "binary"
|
||||
ve.Type.Collate = "binary"
|
||||
case *types.DataItem:
|
||||
ve.Type = value.(*types.DataItem).Type
|
||||
default:
|
||||
panic(fmt.Sprintf("illegal literal value type:%T", value))
|
||||
}
|
||||
return ve
|
||||
}
|
||||
|
||||
// IsStatic implements ExprNode interface.
|
||||
func (val *ValueExpr) IsStatic() bool {
|
||||
func (n *ValueExpr) IsStatic() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Accept implements Node interface.
|
||||
func (val *ValueExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(val); skipChildren {
|
||||
return val, ok
|
||||
func (n *ValueExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(val)
|
||||
n = newNod.(*ValueExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// BetweenExpr is for "between and" or "not between and" expression.
|
||||
@ -78,35 +119,37 @@ type BetweenExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node interface.
|
||||
func (b *BetweenExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(b); skipChildren {
|
||||
return b, ok
|
||||
func (n *BetweenExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*BetweenExpr)
|
||||
|
||||
node, ok := b.Expr.Accept(v)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return b, false
|
||||
return n, false
|
||||
}
|
||||
b.Expr = node.(ExprNode)
|
||||
n.Expr = node.(ExprNode)
|
||||
|
||||
node, ok = b.Left.Accept(v)
|
||||
node, ok = n.Left.Accept(v)
|
||||
if !ok {
|
||||
return b, false
|
||||
return n, false
|
||||
}
|
||||
b.Left = node.(ExprNode)
|
||||
n.Left = node.(ExprNode)
|
||||
|
||||
node, ok = b.Right.Accept(v)
|
||||
node, ok = n.Right.Accept(v)
|
||||
if !ok {
|
||||
return b, false
|
||||
return n, false
|
||||
}
|
||||
b.Right = node.(ExprNode)
|
||||
n.Right = node.(ExprNode)
|
||||
|
||||
return v.Leave(b)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (b *BetweenExpr) IsStatic() bool {
|
||||
return b.Expr.IsStatic() && b.Left.IsStatic() && b.Right.IsStatic()
|
||||
func (n *BetweenExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic() && n.Left.IsStatic() && n.Right.IsStatic()
|
||||
}
|
||||
|
||||
// BinaryOperationExpr is for binary operation like 1 + 1, 1 - 1, etc.
|
||||
@ -121,29 +164,31 @@ type BinaryOperationExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node interface.
|
||||
func (o *BinaryOperationExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(o); skipChildren {
|
||||
return o, ok
|
||||
func (n *BinaryOperationExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*BinaryOperationExpr)
|
||||
|
||||
node, ok := o.L.Accept(v)
|
||||
node, ok := n.L.Accept(v)
|
||||
if !ok {
|
||||
return o, false
|
||||
return n, false
|
||||
}
|
||||
o.L = node.(ExprNode)
|
||||
n.L = node.(ExprNode)
|
||||
|
||||
node, ok = o.R.Accept(v)
|
||||
node, ok = n.R.Accept(v)
|
||||
if !ok {
|
||||
return o, false
|
||||
return n, false
|
||||
}
|
||||
o.R = node.(ExprNode)
|
||||
n.R = node.(ExprNode)
|
||||
|
||||
return v.Leave(o)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (o *BinaryOperationExpr) IsStatic() bool {
|
||||
return o.L.IsStatic() && o.R.IsStatic()
|
||||
func (n *BinaryOperationExpr) IsStatic() bool {
|
||||
return n.L.IsStatic() && n.R.IsStatic()
|
||||
}
|
||||
|
||||
// WhenClause is the when clause in Case expression for "when condition then result".
|
||||
@ -156,27 +201,29 @@ type WhenClause struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (w *WhenClause) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(w); skipChildren {
|
||||
return w, ok
|
||||
func (n *WhenClause) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := w.Expr.Accept(v)
|
||||
n = newNod.(*WhenClause)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return w, false
|
||||
return n, false
|
||||
}
|
||||
w.Expr = node.(ExprNode)
|
||||
n.Expr = node.(ExprNode)
|
||||
|
||||
node, ok = w.Result.Accept(v)
|
||||
node, ok = n.Result.Accept(v)
|
||||
if !ok {
|
||||
return w, false
|
||||
return n, false
|
||||
}
|
||||
w.Result = node.(ExprNode)
|
||||
return v.Leave(w)
|
||||
n.Result = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (w *WhenClause) IsStatic() bool {
|
||||
return w.Expr.IsStatic() && w.Result.IsStatic()
|
||||
func (n *WhenClause) IsStatic() bool {
|
||||
return n.Expr.IsStatic() && n.Result.IsStatic()
|
||||
}
|
||||
|
||||
// CaseExpr is the case expression.
|
||||
@ -191,45 +238,47 @@ type CaseExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (f *CaseExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(f); skipChildren {
|
||||
return f, ok
|
||||
func (n *CaseExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if f.Value != nil {
|
||||
node, ok := f.Value.Accept(v)
|
||||
n = newNod.(*CaseExpr)
|
||||
if n.Value != nil {
|
||||
node, ok := n.Value.Accept(v)
|
||||
if !ok {
|
||||
return f, false
|
||||
return n, false
|
||||
}
|
||||
f.Value = node.(ExprNode)
|
||||
n.Value = node.(ExprNode)
|
||||
}
|
||||
for i, val := range f.WhenClauses {
|
||||
for i, val := range n.WhenClauses {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return f, false
|
||||
return n, false
|
||||
}
|
||||
f.WhenClauses[i] = node.(*WhenClause)
|
||||
n.WhenClauses[i] = node.(*WhenClause)
|
||||
}
|
||||
if f.ElseClause != nil {
|
||||
node, ok := f.ElseClause.Accept(v)
|
||||
if n.ElseClause != nil {
|
||||
node, ok := n.ElseClause.Accept(v)
|
||||
if !ok {
|
||||
return f, false
|
||||
return n, false
|
||||
}
|
||||
f.ElseClause = node.(ExprNode)
|
||||
n.ElseClause = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(f)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (f *CaseExpr) IsStatic() bool {
|
||||
if f.Value != nil && !f.Value.IsStatic() {
|
||||
func (n *CaseExpr) IsStatic() bool {
|
||||
if n.Value != nil && !n.Value.IsStatic() {
|
||||
return false
|
||||
}
|
||||
for _, w := range f.WhenClauses {
|
||||
for _, w := range n.WhenClauses {
|
||||
if !w.IsStatic() {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if f.ElseClause != nil && !f.ElseClause.IsStatic() {
|
||||
if n.ElseClause != nil && !n.ElseClause.IsStatic() {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
@ -239,30 +288,32 @@ func (f *CaseExpr) IsStatic() bool {
|
||||
type SubqueryExpr struct {
|
||||
exprNode
|
||||
// Query is the query SelectNode.
|
||||
Query *SelectStmt
|
||||
Query ResultSetNode
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (sq *SubqueryExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(sq); skipChildren {
|
||||
return sq, ok
|
||||
func (n *SubqueryExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := sq.Query.Accept(v)
|
||||
n = newNod.(*SubqueryExpr)
|
||||
node, ok := n.Query.Accept(v)
|
||||
if !ok {
|
||||
return sq, false
|
||||
return n, false
|
||||
}
|
||||
sq.Query = node.(*SelectStmt)
|
||||
return v.Leave(sq)
|
||||
n.Query = node.(ResultSetNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SetResultFields implements ResultSet interface.
|
||||
func (sq *SubqueryExpr) SetResultFields(rfs []*ResultField) {
|
||||
sq.Query.SetResultFields(rfs)
|
||||
func (n *SubqueryExpr) SetResultFields(rfs []*ResultField) {
|
||||
n.Query.SetResultFields(rfs)
|
||||
}
|
||||
|
||||
// GetResultFields implements ResultSet interface.
|
||||
func (sq *SubqueryExpr) GetResultFields() []*ResultField {
|
||||
return sq.Query.GetResultFields()
|
||||
func (n *SubqueryExpr) GetResultFields() []*ResultField {
|
||||
return n.Query.GetResultFields()
|
||||
}
|
||||
|
||||
// CompareSubqueryExpr is the expression for "expr cmp (select ...)".
|
||||
@ -282,21 +333,23 @@ type CompareSubqueryExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (cs *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(cs); skipChildren {
|
||||
return cs, ok
|
||||
func (n *CompareSubqueryExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := cs.L.Accept(v)
|
||||
n = newNod.(*CompareSubqueryExpr)
|
||||
node, ok := n.L.Accept(v)
|
||||
if !ok {
|
||||
return cs, false
|
||||
return n, false
|
||||
}
|
||||
cs.L = node.(ExprNode)
|
||||
node, ok = cs.R.Accept(v)
|
||||
n.L = node.(ExprNode)
|
||||
node, ok = n.R.Accept(v)
|
||||
if !ok {
|
||||
return cs, false
|
||||
return n, false
|
||||
}
|
||||
cs.R = node.(*SubqueryExpr)
|
||||
return v.Leave(cs)
|
||||
n.R = node.(*SubqueryExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ColumnName represents column name.
|
||||
@ -312,11 +365,13 @@ type ColumnName struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (cn *ColumnName) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(cn); skipChildren {
|
||||
return cn, ok
|
||||
func (n *ColumnName) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(cn)
|
||||
n = newNod.(*ColumnName)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ColumnNameExpr represents a column name expression.
|
||||
@ -328,16 +383,18 @@ type ColumnNameExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (cr *ColumnNameExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(cr); skipChildren {
|
||||
return cr, ok
|
||||
func (n *ColumnNameExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := cr.Name.Accept(v)
|
||||
n = newNod.(*ColumnNameExpr)
|
||||
node, ok := n.Name.Accept(v)
|
||||
if !ok {
|
||||
return cr, false
|
||||
return n, false
|
||||
}
|
||||
cr.Name = node.(*ColumnName)
|
||||
return v.Leave(cr)
|
||||
n.Name = node.(*ColumnName)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// DefaultExpr is the default expression using default value for a column.
|
||||
@ -348,18 +405,20 @@ type DefaultExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (d *DefaultExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(d); skipChildren {
|
||||
return d, ok
|
||||
func (n *DefaultExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if d.Name != nil {
|
||||
node, ok := d.Name.Accept(v)
|
||||
n = newNod.(*DefaultExpr)
|
||||
if n.Name != nil {
|
||||
node, ok := n.Name.Accept(v)
|
||||
if !ok {
|
||||
return d, false
|
||||
return n, false
|
||||
}
|
||||
d.Name = node.(*ColumnName)
|
||||
n.Name = node.(*ColumnName)
|
||||
}
|
||||
return v.Leave(d)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IdentifierExpr represents an identifier expression.
|
||||
@ -370,11 +429,13 @@ type IdentifierExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (i *IdentifierExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(i); skipChildren {
|
||||
return i, ok
|
||||
func (n *IdentifierExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(i)
|
||||
n = newNod.(*IdentifierExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ExistsSubqueryExpr is the expression for "exists (select ...)".
|
||||
@ -386,16 +447,18 @@ type ExistsSubqueryExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (es *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(es); skipChildren {
|
||||
return es, ok
|
||||
func (n *ExistsSubqueryExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := es.Sel.Accept(v)
|
||||
n = newNod.(*ExistsSubqueryExpr)
|
||||
node, ok := n.Sel.Accept(v)
|
||||
if !ok {
|
||||
return es, false
|
||||
return n, false
|
||||
}
|
||||
es.Sel = node.(*SubqueryExpr)
|
||||
return v.Leave(es)
|
||||
n.Sel = node.(*SubqueryExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// PatternInExpr is the expression for in operator, like "expr in (1, 2, 3)" or "expr in (select c from t)".
|
||||
@ -412,30 +475,32 @@ type PatternInExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (pi *PatternInExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(pi); skipChildren {
|
||||
return pi, ok
|
||||
func (n *PatternInExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := pi.Expr.Accept(v)
|
||||
n = newNod.(*PatternInExpr)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return pi, false
|
||||
return n, false
|
||||
}
|
||||
pi.Expr = node.(ExprNode)
|
||||
for i, val := range pi.List {
|
||||
n.Expr = node.(ExprNode)
|
||||
for i, val := range n.List {
|
||||
node, ok = val.Accept(v)
|
||||
if !ok {
|
||||
return pi, false
|
||||
return n, false
|
||||
}
|
||||
pi.List[i] = node.(ExprNode)
|
||||
n.List[i] = node.(ExprNode)
|
||||
}
|
||||
if pi.Sel != nil {
|
||||
node, ok = pi.Sel.Accept(v)
|
||||
if n.Sel != nil {
|
||||
node, ok = n.Sel.Accept(v)
|
||||
if !ok {
|
||||
return pi, false
|
||||
return n, false
|
||||
}
|
||||
pi.Sel = node.(*SubqueryExpr)
|
||||
n.Sel = node.(*SubqueryExpr)
|
||||
}
|
||||
return v.Leave(pi)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsNullExpr is the expression for null check.
|
||||
@ -448,21 +513,23 @@ type IsNullExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (is *IsNullExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(is); skipChildren {
|
||||
return is, ok
|
||||
func (n *IsNullExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := is.Expr.Accept(v)
|
||||
n = newNod.(*IsNullExpr)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return is, false
|
||||
return n, false
|
||||
}
|
||||
is.Expr = node.(ExprNode)
|
||||
return v.Leave(is)
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (is *IsNullExpr) IsStatic() bool {
|
||||
return is.Expr.IsStatic()
|
||||
func (n *IsNullExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic()
|
||||
}
|
||||
|
||||
// IsTruthExpr is the expression for true/false check.
|
||||
@ -477,21 +544,23 @@ type IsTruthExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (is *IsTruthExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(is); skipChildren {
|
||||
return is, ok
|
||||
func (n *IsTruthExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := is.Expr.Accept(v)
|
||||
n = newNod.(*IsTruthExpr)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return is, false
|
||||
return n, false
|
||||
}
|
||||
is.Expr = node.(ExprNode)
|
||||
return v.Leave(is)
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (is *IsTruthExpr) IsStatic() bool {
|
||||
return is.Expr.IsStatic()
|
||||
func (n *IsTruthExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic()
|
||||
}
|
||||
|
||||
// PatternLikeExpr is the expression for like operator, e.g, expr like "%123%"
|
||||
@ -508,40 +577,49 @@ type PatternLikeExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (pl *PatternLikeExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(pl); skipChildren {
|
||||
return pl, ok
|
||||
func (n *PatternLikeExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := pl.Expr.Accept(v)
|
||||
if !ok {
|
||||
return pl, false
|
||||
n = newNod.(*PatternLikeExpr)
|
||||
if n.Expr != nil {
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Expr = node.(ExprNode)
|
||||
}
|
||||
pl.Expr = node.(ExprNode)
|
||||
node, ok = pl.Pattern.Accept(v)
|
||||
if !ok {
|
||||
return pl, false
|
||||
if n.Pattern != nil {
|
||||
node, ok := n.Pattern.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Pattern = node.(ExprNode)
|
||||
}
|
||||
pl.Pattern = node.(ExprNode)
|
||||
return v.Leave(pl)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (pl *PatternLikeExpr) IsStatic() bool {
|
||||
return pl.Expr.IsStatic() && pl.Pattern.IsStatic()
|
||||
func (n *PatternLikeExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic() && n.Pattern.IsStatic()
|
||||
}
|
||||
|
||||
// ParamMarkerExpr expresion holds a place for another expression.
|
||||
// Used in parsing prepare statement.
|
||||
type ParamMarkerExpr struct {
|
||||
exprNode
|
||||
Offset int
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (pm *ParamMarkerExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(pm); skipChildren {
|
||||
return pm, ok
|
||||
func (n *ParamMarkerExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(pm)
|
||||
n = newNod.(*ParamMarkerExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ParenthesesExpr is the parentheses expression.
|
||||
@ -552,23 +630,25 @@ type ParenthesesExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (p *ParenthesesExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(p); skipChildren {
|
||||
return p, ok
|
||||
func (n *ParenthesesExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if p.Expr != nil {
|
||||
node, ok := p.Expr.Accept(v)
|
||||
n = newNod.(*ParenthesesExpr)
|
||||
if n.Expr != nil {
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return p, false
|
||||
return n, false
|
||||
}
|
||||
p.Expr = node.(ExprNode)
|
||||
n.Expr = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(p)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (p *ParenthesesExpr) IsStatic() bool {
|
||||
return p.Expr.IsStatic()
|
||||
func (n *ParenthesesExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic()
|
||||
}
|
||||
|
||||
// PositionExpr is the expression for order by and group by position.
|
||||
@ -583,16 +663,18 @@ type PositionExpr struct {
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (p *PositionExpr) IsStatic() bool {
|
||||
func (n *PositionExpr) IsStatic() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (p *PositionExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(p); skipChildren {
|
||||
return p, ok
|
||||
func (n *PositionExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(p)
|
||||
n = newNod.(*PositionExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// PatternRegexpExpr is the pattern expression for pattern match.
|
||||
@ -607,26 +689,28 @@ type PatternRegexpExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (p *PatternRegexpExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(p); skipChildren {
|
||||
return p, ok
|
||||
func (n *PatternRegexpExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := p.Expr.Accept(v)
|
||||
n = newNod.(*PatternRegexpExpr)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return p, false
|
||||
return n, false
|
||||
}
|
||||
p.Expr = node.(ExprNode)
|
||||
node, ok = p.Pattern.Accept(v)
|
||||
n.Expr = node.(ExprNode)
|
||||
node, ok = n.Pattern.Accept(v)
|
||||
if !ok {
|
||||
return p, false
|
||||
return n, false
|
||||
}
|
||||
p.Pattern = node.(ExprNode)
|
||||
return v.Leave(p)
|
||||
n.Pattern = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (p *PatternRegexpExpr) IsStatic() bool {
|
||||
return p.Expr.IsStatic() && p.Pattern.IsStatic()
|
||||
func (n *PatternRegexpExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic() && n.Pattern.IsStatic()
|
||||
}
|
||||
|
||||
// RowExpr is the expression for row constructor.
|
||||
@ -638,23 +722,25 @@ type RowExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (r *RowExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(r); skipChildren {
|
||||
return r, ok
|
||||
func (n *RowExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range r.Values {
|
||||
n = newNod.(*RowExpr)
|
||||
for i, val := range n.Values {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return r, false
|
||||
return n, false
|
||||
}
|
||||
r.Values[i] = node.(ExprNode)
|
||||
n.Values[i] = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(r)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (r *RowExpr) IsStatic() bool {
|
||||
for _, v := range r.Values {
|
||||
func (n *RowExpr) IsStatic() bool {
|
||||
for _, v := range n.Values {
|
||||
if !v.IsStatic() {
|
||||
return false
|
||||
}
|
||||
@ -672,21 +758,23 @@ type UnaryOperationExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (u *UnaryOperationExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(u); skipChildren {
|
||||
return u, ok
|
||||
func (n *UnaryOperationExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := u.V.Accept(v)
|
||||
n = newNod.(*UnaryOperationExpr)
|
||||
node, ok := n.V.Accept(v)
|
||||
if !ok {
|
||||
return u, false
|
||||
return n, false
|
||||
}
|
||||
u.V = node.(ExprNode)
|
||||
return v.Leave(u)
|
||||
n.V = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (u *UnaryOperationExpr) IsStatic() bool {
|
||||
return u.V.IsStatic()
|
||||
func (n *UnaryOperationExpr) IsStatic() bool {
|
||||
return n.V.IsStatic()
|
||||
}
|
||||
|
||||
// ValuesExpr is the expression used in INSERT VALUES
|
||||
@ -697,16 +785,18 @@ type ValuesExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (va *ValuesExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(va); skipChildren {
|
||||
return va, ok
|
||||
func (n *ValuesExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := va.Column.Accept(v)
|
||||
n = newNod.(*ValuesExpr)
|
||||
node, ok := n.Column.Accept(v)
|
||||
if !ok {
|
||||
return va, false
|
||||
return n, false
|
||||
}
|
||||
va.Column = node.(*ColumnName)
|
||||
return v.Leave(va)
|
||||
n.Column = node.(*ColumnName)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// VariableExpr is the expression for variable.
|
||||
@ -721,9 +811,11 @@ type VariableExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (va *VariableExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(va); skipChildren {
|
||||
return va, ok
|
||||
func (n *VariableExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(va)
|
||||
n = newNod.(*VariableExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
288
ast/functions.go
288
ast/functions.go
@ -26,44 +26,49 @@ var (
|
||||
_ FuncNode = &FuncConvertExpr{}
|
||||
_ FuncNode = &FuncCastExpr{}
|
||||
_ FuncNode = &FuncSubstringExpr{}
|
||||
_ FuncNode = &FuncLocateExpr{}
|
||||
_ FuncNode = &FuncTrimExpr{}
|
||||
_ FuncNode = &FuncDateArithExpr{}
|
||||
_ FuncNode = &AggregateFuncExpr{}
|
||||
)
|
||||
|
||||
// UnquoteString is not quoted when printed.
|
||||
type UnquoteString string
|
||||
|
||||
// FuncCallExpr is for function expression.
|
||||
type FuncCallExpr struct {
|
||||
funcNode
|
||||
// F is the function name.
|
||||
F string
|
||||
FnName string
|
||||
// Args is the function args.
|
||||
Args []ExprNode
|
||||
// Distinct only affetcts sum, avg, count, group_concat,
|
||||
// so we can ignore it in other functions
|
||||
Distinct bool
|
||||
}
|
||||
|
||||
// Accept implements Node interface.
|
||||
func (c *FuncCallExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(c); skipChildren {
|
||||
return c, ok
|
||||
func (n *FuncCallExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range c.Args {
|
||||
n = newNod.(*FuncCallExpr)
|
||||
for i, val := range n.Args {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return c, false
|
||||
return n, false
|
||||
}
|
||||
c.Args[i] = node.(ExprNode)
|
||||
n.Args[i] = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(c)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (c *FuncCallExpr) IsStatic() bool {
|
||||
v := builtin.Funcs[strings.ToLower(c.F)]
|
||||
func (n *FuncCallExpr) IsStatic() bool {
|
||||
v := builtin.Funcs[strings.ToLower(n.FnName)]
|
||||
if v.F == nil || !v.IsStatic {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, v := range c.Args {
|
||||
for _, v := range n.Args {
|
||||
if !v.IsStatic() {
|
||||
return false
|
||||
}
|
||||
@ -81,21 +86,23 @@ type FuncExtractExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ex *FuncExtractExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ex); skipChildren {
|
||||
return ex, ok
|
||||
func (n *FuncExtractExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ex.Date.Accept(v)
|
||||
n = newNod.(*FuncExtractExpr)
|
||||
node, ok := n.Date.Accept(v)
|
||||
if !ok {
|
||||
return ex, false
|
||||
return n, false
|
||||
}
|
||||
ex.Date = node.(ExprNode)
|
||||
return v.Leave(ex)
|
||||
n.Date = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (ex *FuncExtractExpr) IsStatic() bool {
|
||||
return ex.Date.IsStatic()
|
||||
func (n *FuncExtractExpr) IsStatic() bool {
|
||||
return n.Date.IsStatic()
|
||||
}
|
||||
|
||||
// FuncConvertExpr provides a way to convert data between different character sets.
|
||||
@ -109,21 +116,23 @@ type FuncConvertExpr struct {
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (f *FuncConvertExpr) IsStatic() bool {
|
||||
return f.Expr.IsStatic()
|
||||
func (n *FuncConvertExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic()
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (f *FuncConvertExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(f); skipChildren {
|
||||
return f, ok
|
||||
func (n *FuncConvertExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := f.Expr.Accept(v)
|
||||
n = newNod.(*FuncConvertExpr)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return f, false
|
||||
return n, false
|
||||
}
|
||||
f.Expr = node.(ExprNode)
|
||||
return v.Leave(f)
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// CastFunctionType is the type for cast function.
|
||||
@ -149,21 +158,23 @@ type FuncCastExpr struct {
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (f *FuncCastExpr) IsStatic() bool {
|
||||
return f.Expr.IsStatic()
|
||||
func (n *FuncCastExpr) IsStatic() bool {
|
||||
return n.Expr.IsStatic()
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (f *FuncCastExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(f); skipChildren {
|
||||
return f, ok
|
||||
func (n *FuncCastExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := f.Expr.Accept(v)
|
||||
n = newNod.(*FuncCastExpr)
|
||||
node, ok := n.Expr.Accept(v)
|
||||
if !ok {
|
||||
return f, false
|
||||
return n, false
|
||||
}
|
||||
f.Expr = node.(ExprNode)
|
||||
return v.Leave(f)
|
||||
n.Expr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// FuncSubstringExpr returns the substring as specified.
|
||||
@ -177,31 +188,35 @@ type FuncSubstringExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (sf *FuncSubstringExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(sf); skipChildren {
|
||||
return sf, ok
|
||||
func (n *FuncSubstringExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := sf.StrExpr.Accept(v)
|
||||
n = newNod.(*FuncSubstringExpr)
|
||||
node, ok := n.StrExpr.Accept(v)
|
||||
if !ok {
|
||||
return sf, false
|
||||
return n, false
|
||||
}
|
||||
sf.StrExpr = node.(ExprNode)
|
||||
node, ok = sf.Pos.Accept(v)
|
||||
n.StrExpr = node.(ExprNode)
|
||||
node, ok = n.Pos.Accept(v)
|
||||
if !ok {
|
||||
return sf, false
|
||||
return n, false
|
||||
}
|
||||
sf.Pos = node.(ExprNode)
|
||||
node, ok = sf.Len.Accept(v)
|
||||
if !ok {
|
||||
return sf, false
|
||||
n.Pos = node.(ExprNode)
|
||||
if n.Len != nil {
|
||||
node, ok = n.Len.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Len = node.(ExprNode)
|
||||
}
|
||||
sf.Len = node.(ExprNode)
|
||||
return v.Leave(sf)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (sf *FuncSubstringExpr) IsStatic() bool {
|
||||
return sf.StrExpr.IsStatic() && sf.Pos.IsStatic() && sf.Len.IsStatic()
|
||||
func (n *FuncSubstringExpr) IsStatic() bool {
|
||||
return n.StrExpr.IsStatic() && n.Pos.IsStatic() && n.Len.IsStatic()
|
||||
}
|
||||
|
||||
// FuncSubstringIndexExpr returns the substring as specified.
|
||||
@ -215,26 +230,28 @@ type FuncSubstringIndexExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (si *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(si); skipChildren {
|
||||
return si, ok
|
||||
func (n *FuncSubstringIndexExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := si.StrExpr.Accept(v)
|
||||
n = newNod.(*FuncSubstringIndexExpr)
|
||||
node, ok := n.StrExpr.Accept(v)
|
||||
if !ok {
|
||||
return si, false
|
||||
return n, false
|
||||
}
|
||||
si.StrExpr = node.(ExprNode)
|
||||
node, ok = si.Delim.Accept(v)
|
||||
n.StrExpr = node.(ExprNode)
|
||||
node, ok = n.Delim.Accept(v)
|
||||
if !ok {
|
||||
return si, false
|
||||
return n, false
|
||||
}
|
||||
si.Delim = node.(ExprNode)
|
||||
node, ok = si.Count.Accept(v)
|
||||
n.Delim = node.(ExprNode)
|
||||
node, ok = n.Count.Accept(v)
|
||||
if !ok {
|
||||
return si, false
|
||||
return n, false
|
||||
}
|
||||
si.Count = node.(ExprNode)
|
||||
return v.Leave(si)
|
||||
n.Count = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// FuncLocateExpr returns the position of the first occurrence of substring.
|
||||
@ -248,26 +265,28 @@ type FuncLocateExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (le *FuncLocateExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(le); skipChildren {
|
||||
return le, ok
|
||||
func (n *FuncLocateExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := le.Str.Accept(v)
|
||||
n = newNod.(*FuncLocateExpr)
|
||||
node, ok := n.Str.Accept(v)
|
||||
if !ok {
|
||||
return le, false
|
||||
return n, false
|
||||
}
|
||||
le.Str = node.(ExprNode)
|
||||
node, ok = le.SubStr.Accept(v)
|
||||
n.Str = node.(ExprNode)
|
||||
node, ok = n.SubStr.Accept(v)
|
||||
if !ok {
|
||||
return le, false
|
||||
return n, false
|
||||
}
|
||||
le.SubStr = node.(ExprNode)
|
||||
node, ok = le.Pos.Accept(v)
|
||||
n.SubStr = node.(ExprNode)
|
||||
node, ok = n.Pos.Accept(v)
|
||||
if !ok {
|
||||
return le, false
|
||||
return n, false
|
||||
}
|
||||
le.Pos = node.(ExprNode)
|
||||
return v.Leave(le)
|
||||
n.Pos = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// TrimDirectionType is the type for trim direction.
|
||||
@ -295,27 +314,102 @@ type FuncTrimExpr struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (tf *FuncTrimExpr) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(tf); skipChildren {
|
||||
return tf, ok
|
||||
func (n *FuncTrimExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := tf.Str.Accept(v)
|
||||
n = newNod.(*FuncTrimExpr)
|
||||
node, ok := n.Str.Accept(v)
|
||||
if !ok {
|
||||
return tf, false
|
||||
return n, false
|
||||
}
|
||||
tf.Str = node.(ExprNode)
|
||||
node, ok = tf.RemStr.Accept(v)
|
||||
n.Str = node.(ExprNode)
|
||||
node, ok = n.RemStr.Accept(v)
|
||||
if !ok {
|
||||
return tf, false
|
||||
return n, false
|
||||
}
|
||||
tf.RemStr = node.(ExprNode)
|
||||
return v.Leave(tf)
|
||||
n.RemStr = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// IsStatic implements the ExprNode IsStatic interface.
|
||||
func (tf *FuncTrimExpr) IsStatic() bool {
|
||||
return tf.Str.IsStatic() && tf.RemStr.IsStatic()
|
||||
func (n *FuncTrimExpr) IsStatic() bool {
|
||||
return n.Str.IsStatic() && n.RemStr.IsStatic()
|
||||
}
|
||||
|
||||
// TypeStar is a special type for "*".
|
||||
type TypeStar string
|
||||
// DateArithType is type for DateArith option.
|
||||
type DateArithType byte
|
||||
|
||||
const (
|
||||
// DateAdd is to run date_add function option.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
|
||||
DateAdd DateArithType = iota + 1
|
||||
// DateSub is to run date_sub function option.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
|
||||
DateSub
|
||||
)
|
||||
|
||||
// FuncDateArithExpr is the struct for date arithmetic functions.
|
||||
type FuncDateArithExpr struct {
|
||||
funcNode
|
||||
|
||||
Op DateArithType
|
||||
Unit string
|
||||
Date ExprNode
|
||||
Interval ExprNode
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*FuncDateArithExpr)
|
||||
if n.Date != nil {
|
||||
node, ok := n.Date.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Date = node.(ExprNode)
|
||||
}
|
||||
if n.Interval != nil {
|
||||
node, ok := n.Date.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Date = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// AggregateFuncExpr represents aggregate function expression.
|
||||
type AggregateFuncExpr struct {
|
||||
funcNode
|
||||
// F is the function name.
|
||||
F string
|
||||
// Args is the function args.
|
||||
Args []ExprNode
|
||||
// If distinct is true, the function only aggregate distinct values.
|
||||
// For example, column c1 values are "1", "2", "2", "sum(c1)" is "5",
|
||||
// but "sum(distinct c1)" is "3".
|
||||
Distinct bool
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *AggregateFuncExpr) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*AggregateFuncExpr)
|
||||
for i, val := range n.Args {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Args[i] = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
310
ast/misc.go
310
ast/misc.go
@ -13,6 +13,8 @@
|
||||
|
||||
package ast
|
||||
|
||||
import "github.com/pingcap/tidb/mysql"
|
||||
|
||||
var (
|
||||
_ StmtNode = &ExplainStmt{}
|
||||
_ StmtNode = &PrepareStmt{}
|
||||
@ -26,7 +28,9 @@ var (
|
||||
_ StmtNode = &SetStmt{}
|
||||
_ StmtNode = &SetCharsetStmt{}
|
||||
_ StmtNode = &SetPwdStmt{}
|
||||
_ StmtNode = &CreateUserStmt{}
|
||||
_ StmtNode = &DoStmt{}
|
||||
_ StmtNode = &GrantStmt{}
|
||||
|
||||
_ Node = &VariableAssignment{}
|
||||
)
|
||||
@ -57,16 +61,18 @@ type ExplainStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (es *ExplainStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(es); skipChildren {
|
||||
return es, ok
|
||||
func (n *ExplainStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := es.Stmt.Accept(v)
|
||||
n = newNod.(*ExplainStmt)
|
||||
node, ok := n.Stmt.Accept(v)
|
||||
if !ok {
|
||||
return es, false
|
||||
return n, false
|
||||
}
|
||||
es.Stmt = node.(DMLNode)
|
||||
return v.Leave(es)
|
||||
n.Stmt = node.(DMLNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// PrepareStmt is a statement to prepares a SQL statement which contains placeholders,
|
||||
@ -83,16 +89,18 @@ type PrepareStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ps *PrepareStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ps); skipChildren {
|
||||
return ps, ok
|
||||
func (n *PrepareStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := ps.SQLVar.Accept(v)
|
||||
n = newNod.(*PrepareStmt)
|
||||
node, ok := n.SQLVar.Accept(v)
|
||||
if !ok {
|
||||
return ps, false
|
||||
return n, false
|
||||
}
|
||||
ps.SQLVar = node.(*VariableExpr)
|
||||
return v.Leave(ps)
|
||||
n.SQLVar = node.(*VariableExpr)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// DeallocateStmt is a statement to release PreparedStmt.
|
||||
@ -105,11 +113,13 @@ type DeallocateStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ds *DeallocateStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ds); skipChildren {
|
||||
return ds, ok
|
||||
func (n *DeallocateStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(ds)
|
||||
n = newNod.(*DeallocateStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ExecuteStmt is a statement to execute PreparedStmt.
|
||||
@ -123,18 +133,20 @@ type ExecuteStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (es *ExecuteStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(es); skipChildren {
|
||||
return es, ok
|
||||
func (n *ExecuteStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range es.UsingVars {
|
||||
n = newNod.(*ExecuteStmt)
|
||||
for i, val := range n.UsingVars {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return es, false
|
||||
return n, false
|
||||
}
|
||||
es.UsingVars[i] = node.(ExprNode)
|
||||
n.UsingVars[i] = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(es)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ShowStmtType is the type for SHOW statement.
|
||||
@ -152,12 +164,13 @@ const (
|
||||
ShowVariables
|
||||
ShowCollation
|
||||
ShowCreateTable
|
||||
ShowGrants
|
||||
)
|
||||
|
||||
// ShowStmt is a statement to provide information about databases, tables, columns and so on.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/show.html
|
||||
type ShowStmt struct {
|
||||
stmtNode
|
||||
dmlNode
|
||||
|
||||
Tp ShowStmtType // Databases/Tables/Columns/....
|
||||
DBName string
|
||||
@ -165,6 +178,7 @@ type ShowStmt struct {
|
||||
Column *ColumnName // Used for `desc table column`.
|
||||
Flag int // Some flag parsed from sql, such as FULL.
|
||||
Full bool
|
||||
User string // Used for show grants.
|
||||
|
||||
// Used by show variables
|
||||
GlobalScope bool
|
||||
@ -173,39 +187,41 @@ type ShowStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (ss *ShowStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(ss); skipChildren {
|
||||
return ss, ok
|
||||
func (n *ShowStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
if ss.Table != nil {
|
||||
node, ok := ss.Table.Accept(v)
|
||||
n = newNod.(*ShowStmt)
|
||||
if n.Table != nil {
|
||||
node, ok := n.Table.Accept(v)
|
||||
if !ok {
|
||||
return ss, false
|
||||
return n, false
|
||||
}
|
||||
ss.Table = node.(*TableName)
|
||||
n.Table = node.(*TableName)
|
||||
}
|
||||
if ss.Column != nil {
|
||||
node, ok := ss.Column.Accept(v)
|
||||
if n.Column != nil {
|
||||
node, ok := n.Column.Accept(v)
|
||||
if !ok {
|
||||
return ss, false
|
||||
return n, false
|
||||
}
|
||||
ss.Column = node.(*ColumnName)
|
||||
n.Column = node.(*ColumnName)
|
||||
}
|
||||
if ss.Pattern != nil {
|
||||
node, ok := ss.Pattern.Accept(v)
|
||||
if n.Pattern != nil {
|
||||
node, ok := n.Pattern.Accept(v)
|
||||
if !ok {
|
||||
return ss, false
|
||||
return n, false
|
||||
}
|
||||
ss.Pattern = node.(*PatternLikeExpr)
|
||||
n.Pattern = node.(*PatternLikeExpr)
|
||||
}
|
||||
if ss.Where != nil {
|
||||
node, ok := ss.Where.Accept(v)
|
||||
if n.Where != nil {
|
||||
node, ok := n.Where.Accept(v)
|
||||
if !ok {
|
||||
return ss, false
|
||||
return n, false
|
||||
}
|
||||
ss.Where = node.(ExprNode)
|
||||
n.Where = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(ss)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// BeginStmt is a statement to start a new transaction.
|
||||
@ -215,11 +231,13 @@ type BeginStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (bs *BeginStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(bs); skipChildren {
|
||||
return bs, ok
|
||||
func (n *BeginStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(bs)
|
||||
n = newNod.(*BeginStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// CommitStmt is a statement to commit the current transaction.
|
||||
@ -229,11 +247,13 @@ type CommitStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (cs *CommitStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(cs); skipChildren {
|
||||
return cs, ok
|
||||
func (n *CommitStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(cs)
|
||||
n = newNod.(*CommitStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// RollbackStmt is a statement to roll back the current transaction.
|
||||
@ -243,11 +263,13 @@ type RollbackStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (rs *RollbackStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(rs); skipChildren {
|
||||
return rs, ok
|
||||
func (n *RollbackStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(rs)
|
||||
n = newNod.(*RollbackStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// UseStmt is a statement to use the DBName database as the current database.
|
||||
@ -259,11 +281,13 @@ type UseStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (us *UseStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(us); skipChildren {
|
||||
return us, ok
|
||||
func (n *UseStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(us)
|
||||
n = newNod.(*UseStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// VariableAssignment is a variable assignment struct.
|
||||
@ -276,16 +300,18 @@ type VariableAssignment struct {
|
||||
}
|
||||
|
||||
// Accept implements Node interface.
|
||||
func (va *VariableAssignment) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(va); skipChildren {
|
||||
return va, ok
|
||||
func (n *VariableAssignment) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
node, ok := va.Value.Accept(v)
|
||||
n = newNod.(*VariableAssignment)
|
||||
node, ok := n.Value.Accept(v)
|
||||
if !ok {
|
||||
return va, false
|
||||
return n, false
|
||||
}
|
||||
va.Value = node.(ExprNode)
|
||||
return v.Leave(va)
|
||||
n.Value = node.(ExprNode)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SetStmt is the statement to set variables.
|
||||
@ -296,18 +322,20 @@ type SetStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (set *SetStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(set); skipChildren {
|
||||
return set, ok
|
||||
func (n *SetStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range set.Variables {
|
||||
n = newNod.(*SetStmt)
|
||||
for i, val := range n.Variables {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return set, false
|
||||
return n, false
|
||||
}
|
||||
set.Variables[i] = node.(*VariableAssignment)
|
||||
n.Variables[i] = node.(*VariableAssignment)
|
||||
}
|
||||
return v.Leave(set)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SetCharsetStmt is a statement to assign values to character and collation variables.
|
||||
@ -320,11 +348,13 @@ type SetCharsetStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (set *SetCharsetStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(set); skipChildren {
|
||||
return set, ok
|
||||
func (n *SetCharsetStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(set)
|
||||
n = newNod.(*SetCharsetStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// SetPwdStmt is a statement to assign a password to user account.
|
||||
@ -337,11 +367,13 @@ type SetPwdStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (set *SetPwdStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(set); skipChildren {
|
||||
return set, ok
|
||||
func (n *SetPwdStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
return v.Leave(set)
|
||||
n = newNod.(*SetPwdStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// UserSpec is used for parsing create user statement.
|
||||
@ -353,10 +385,20 @@ type UserSpec struct {
|
||||
// CreateUserStmt creates user account.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html
|
||||
type CreateUserStmt struct {
|
||||
stmtNode
|
||||
|
||||
IfNotExists bool
|
||||
Specs []*UserSpec
|
||||
}
|
||||
|
||||
Text string
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *CreateUserStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*CreateUserStmt)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// DoStmt is the struct for DO statement.
|
||||
@ -367,16 +409,100 @@ type DoStmt struct {
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (do *DoStmt) Accept(v Visitor) (Node, bool) {
|
||||
if skipChildren, ok := v.Enter(do); skipChildren {
|
||||
return do, ok
|
||||
func (n *DoStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
for i, val := range do.Exprs {
|
||||
n = newNod.(*DoStmt)
|
||||
for i, val := range n.Exprs {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return do, false
|
||||
return n, false
|
||||
}
|
||||
do.Exprs[i] = node.(ExprNode)
|
||||
n.Exprs[i] = node.(ExprNode)
|
||||
}
|
||||
return v.Leave(do)
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// PrivElem is the privilege type and optional column list.
|
||||
type PrivElem struct {
|
||||
node
|
||||
Priv mysql.PrivilegeType
|
||||
Cols []*ColumnName
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *PrivElem) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*PrivElem)
|
||||
for i, val := range n.Cols {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Cols[i] = node.(*ColumnName)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
// ObjectTypeType is the type for object type.
|
||||
type ObjectTypeType int
|
||||
|
||||
const (
|
||||
// ObjectTypeNone is for empty object type.
|
||||
ObjectTypeNone ObjectTypeType = iota
|
||||
// ObjectTypeTable means the following object is a table.
|
||||
ObjectTypeTable
|
||||
)
|
||||
|
||||
// GrantLevelType is the type for grant level.
|
||||
type GrantLevelType int
|
||||
|
||||
const (
|
||||
// GrantLevelNone is the dummy const for default value.
|
||||
GrantLevelNone GrantLevelType = iota
|
||||
// GrantLevelGlobal means the privileges are administrative or apply to all databases on a given server.
|
||||
GrantLevelGlobal
|
||||
// GrantLevelDB means the privileges apply to all objects in a given database.
|
||||
GrantLevelDB
|
||||
// GrantLevelTable means the privileges apply to all columns in a given table.
|
||||
GrantLevelTable
|
||||
)
|
||||
|
||||
// GrantLevel is used for store the privilege scope.
|
||||
type GrantLevel struct {
|
||||
Level GrantLevelType
|
||||
DBName string
|
||||
TableName string
|
||||
}
|
||||
|
||||
// GrantStmt is the struct for GRANT statement.
|
||||
type GrantStmt struct {
|
||||
stmtNode
|
||||
|
||||
Privs []*PrivElem
|
||||
ObjectType ObjectTypeType
|
||||
Level *GrantLevel
|
||||
Users []*UserSpec
|
||||
}
|
||||
|
||||
// Accept implements Node Accept interface.
|
||||
func (n *GrantStmt) Accept(v Visitor) (Node, bool) {
|
||||
newNod, skipChildren := v.Enter(n)
|
||||
if skipChildren {
|
||||
return v.Leave(newNod)
|
||||
}
|
||||
n = newNod.(*GrantStmt)
|
||||
for i, val := range n.Privs {
|
||||
node, ok := val.Accept(v)
|
||||
if !ok {
|
||||
return n, false
|
||||
}
|
||||
n.Privs[i] = node.(*PrivElem)
|
||||
}
|
||||
return v.Leave(n)
|
||||
}
|
||||
|
||||
3917
ast/parser/parser.y
3917
ast/parser/parser.y
File diff suppressed because it is too large
Load Diff
@ -1,79 +0,0 @@
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
)
|
||||
|
||||
func TestT(t *testing.T) {
|
||||
TestingT(t)
|
||||
}
|
||||
|
||||
var _ = Suite(&testParserSuite{})
|
||||
|
||||
var _ = Suite(&testParserSuite{})
|
||||
|
||||
type testParserSuite struct {
|
||||
}
|
||||
|
||||
func (s *testParserSuite) TestSimple(c *C) {
|
||||
// Testcase for unreserved keywords
|
||||
unreservedKws := []string{
|
||||
"auto_increment", "after", "begin", "bit", "bool", "boolean", "charset", "columns", "commit",
|
||||
"date", "datetime", "deallocate", "do", "end", "engine", "engines", "execute", "first", "full",
|
||||
"local", "names", "offset", "password", "prepare", "quick", "rollback", "session", "signed",
|
||||
"start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown",
|
||||
"value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified",
|
||||
"collation", "comment", "avg_row_length", "checksum", "compression", "connection", "key_block_size",
|
||||
"max_rows", "min_rows", "national", "row", "quarter", "escape",
|
||||
}
|
||||
for _, kw := range unreservedKws {
|
||||
src := fmt.Sprintf("SELECT %s FROM tbl;", kw)
|
||||
l := NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
c.Assert(l.errs, HasLen, 0, Commentf("source %s", src))
|
||||
}
|
||||
|
||||
// Testcase for prepared statement
|
||||
src := "SELECT id+?, id+? from t;"
|
||||
l := NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
c.Assert(len(l.Stmts()), Equals, 1)
|
||||
|
||||
// Testcase for -- Comment and unary -- operator
|
||||
src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;"
|
||||
l = NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
c.Assert(len(l.Stmts()), Equals, 2)
|
||||
|
||||
// Testcase for CONVERT(expr,type)
|
||||
src = "SELECT CONVERT('111', SIGNED);"
|
||||
l = NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
st := l.Stmts()[0]
|
||||
ss, ok := st.(*ast.SelectStmt)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(len(ss.Fields), Equals, 1)
|
||||
cv, ok := ss.Fields[0].Expr.(*ast.FuncCastExpr)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction)
|
||||
|
||||
// For query start with comment
|
||||
srcs := []string{
|
||||
"/* some comments */ SELECT CONVERT('111', SIGNED) ;",
|
||||
"/* some comments */ /*comment*/ SELECT CONVERT('111', SIGNED) ;",
|
||||
"SELECT /*comment*/ CONVERT('111', SIGNED) ;",
|
||||
"SELECT CONVERT('111', /*comment*/ SIGNED) ;",
|
||||
"SELECT CONVERT('111', SIGNED) /*comment*/;",
|
||||
}
|
||||
for _, src := range srcs {
|
||||
l = NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
st = l.Stmts()[0]
|
||||
ss, ok = st.(*ast.SelectStmt)
|
||||
c.Assert(ok, IsTrue)
|
||||
}
|
||||
}
|
||||
1067
ast/parser/scanner.l
1067
ast/parser/scanner.l
File diff suppressed because it is too large
Load Diff
@ -24,6 +24,7 @@ import (
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/optimizer"
|
||||
"github.com/pingcap/tidb/parser"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/stmt"
|
||||
@ -197,5 +198,7 @@ func statement(sql string) stmt.Statement {
|
||||
log.Debug("Compile", sql)
|
||||
lexer := parser.NewLexer(sql)
|
||||
parser.YYParse(lexer)
|
||||
return lexer.Stmts()[0].(stmt.Statement)
|
||||
compiler := &optimizer.Compiler{}
|
||||
stm, _ := compiler.Compile(lexer.Stmts()[0])
|
||||
return stm
|
||||
}
|
||||
|
||||
11
make.cmd
11
make.cmd
@ -18,17 +18,6 @@ DEL /F /A /Q y.output
|
||||
|
||||
golex -o parser/scanner.go parser/scanner.l
|
||||
|
||||
@echo [Ast-Parser]
|
||||
go get github.com/qiuyesuifeng/goyacc
|
||||
go get github.com/qiuyesuifeng/golex
|
||||
type nul >>temp.XXXXXX
|
||||
goyacc -o nul -xegen "temp.XXXXXX" parser/parser.y
|
||||
goyacc -o ast/parser/parser.go -xe "temp.XXXXXX" ast/parser/parser.y
|
||||
DEL /F /A /Q temp.XXXXXX
|
||||
DEL /F /A /Q y.output
|
||||
|
||||
golex -o ast/parser/scanner.go ast/parser/scanner.l
|
||||
|
||||
@echo [Build]
|
||||
godep go build -ldflags '%LDFLAGS%'
|
||||
|
||||
|
||||
26
optimizer/aggregator.go
Normal file
26
optimizer/aggregator.go
Normal file
@ -0,0 +1,26 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer
|
||||
|
||||
// Aggregator is the interface to
|
||||
// compute aggregate function result.
|
||||
type Aggregator interface {
|
||||
// Input adds an input value to aggregator.
|
||||
// The input values are accumulated in the aggregator.
|
||||
Input(in ...interface{}) error
|
||||
// Output uses input values to compute the aggregated result.
|
||||
Output() interface{}
|
||||
// Clear clears the input values.
|
||||
Clear()
|
||||
}
|
||||
132
optimizer/compiler.go
Normal file
132
optimizer/compiler.go
Normal file
@ -0,0 +1,132 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/stmt"
|
||||
)
|
||||
|
||||
// Compiler compiles ast.Node into an executable statement.
|
||||
type Compiler struct {
|
||||
converter *expressionConverter
|
||||
}
|
||||
|
||||
// Compile compiles a ast.Node into an executable statement.
|
||||
func (com *Compiler) Compile(node ast.Node) (stmt.Statement, error) {
|
||||
validator := &validator{}
|
||||
if _, ok := node.Accept(validator); !ok {
|
||||
return nil, errors.Trace(validator.err)
|
||||
}
|
||||
|
||||
// binder := &InfoBinder{}
|
||||
// if _, ok := node.Accept(validator); !ok {
|
||||
// return nil, errors.Trace(binder.Err)
|
||||
// }
|
||||
|
||||
tpComputer := &typeComputer{}
|
||||
if _, ok := node.Accept(tpComputer); !ok {
|
||||
return nil, errors.Trace(tpComputer.err)
|
||||
}
|
||||
c := newExpressionConverter()
|
||||
com.converter = c
|
||||
switch v := node.(type) {
|
||||
case *ast.InsertStmt:
|
||||
return convertInsert(c, v)
|
||||
case *ast.DeleteStmt:
|
||||
return convertDelete(c, v)
|
||||
case *ast.UpdateStmt:
|
||||
return convertUpdate(c, v)
|
||||
case *ast.SelectStmt:
|
||||
return convertSelect(c, v)
|
||||
case *ast.UnionStmt:
|
||||
return convertUnion(c, v)
|
||||
case *ast.CreateDatabaseStmt:
|
||||
return convertCreateDatabase(c, v)
|
||||
case *ast.DropDatabaseStmt:
|
||||
return convertDropDatabase(c, v)
|
||||
case *ast.CreateTableStmt:
|
||||
return convertCreateTable(c, v)
|
||||
case *ast.DropTableStmt:
|
||||
return convertDropTable(c, v)
|
||||
case *ast.CreateIndexStmt:
|
||||
return convertCreateIndex(c, v)
|
||||
case *ast.DropIndexStmt:
|
||||
return convertDropIndex(c, v)
|
||||
case *ast.AlterTableStmt:
|
||||
return convertAlterTable(c, v)
|
||||
case *ast.TruncateTableStmt:
|
||||
return convertTruncateTable(c, v)
|
||||
case *ast.ExplainStmt:
|
||||
return convertExplain(c, v)
|
||||
case *ast.PrepareStmt:
|
||||
return convertPrepare(c, v)
|
||||
case *ast.DeallocateStmt:
|
||||
return convertDeallocate(c, v)
|
||||
case *ast.ExecuteStmt:
|
||||
return convertExecute(c, v)
|
||||
case *ast.ShowStmt:
|
||||
return convertShow(c, v)
|
||||
case *ast.BeginStmt:
|
||||
return convertBegin(c, v)
|
||||
case *ast.CommitStmt:
|
||||
return convertCommit(c, v)
|
||||
case *ast.RollbackStmt:
|
||||
return convertRollback(c, v)
|
||||
case *ast.UseStmt:
|
||||
return convertUse(c, v)
|
||||
case *ast.SetStmt:
|
||||
return convertSet(c, v)
|
||||
case *ast.SetCharsetStmt:
|
||||
return convertSetCharset(c, v)
|
||||
case *ast.SetPwdStmt:
|
||||
return convertSetPwd(c, v)
|
||||
case *ast.CreateUserStmt:
|
||||
return convertCreateUser(c, v)
|
||||
case *ast.DoStmt:
|
||||
return convertDo(c, v)
|
||||
case *ast.GrantStmt:
|
||||
return convertGrant(c, v)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type paramMarkers []*ast.ParamMarkerExpr
|
||||
|
||||
func (p paramMarkers) Len() int {
|
||||
return len(p)
|
||||
}
|
||||
|
||||
func (p paramMarkers) Less(i, j int) bool {
|
||||
return p[i].Offset < p[j].Offset
|
||||
}
|
||||
|
||||
func (p paramMarkers) Swap(i, j int) {
|
||||
p[i], p[j] = p[j], p[i]
|
||||
}
|
||||
|
||||
// ParamMarkers returns parameter markers for prepared statement.
|
||||
func (com *Compiler) ParamMarkers() []*expression.ParamMarker {
|
||||
c := com.converter
|
||||
sort.Sort(c.paramMarkers)
|
||||
oldMarkers := make([]*expression.ParamMarker, len(c.paramMarkers))
|
||||
for i, val := range c.paramMarkers {
|
||||
oldMarkers[i] = c.exprMap[val].(*expression.ParamMarker)
|
||||
}
|
||||
return oldMarkers
|
||||
}
|
||||
425
optimizer/convert_expr.go
Normal file
425
optimizer/convert_expr.go
Normal file
@ -0,0 +1,425 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer
|
||||
|
||||
import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/expression/subquery"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func convertExpr(converter *expressionConverter, expr ast.ExprNode) (expression.Expression, error) {
|
||||
expr.Accept(converter)
|
||||
if converter.err != nil {
|
||||
return nil, errors.Trace(converter.err)
|
||||
}
|
||||
return converter.exprMap[expr], nil
|
||||
}
|
||||
|
||||
// expressionConverter converts ast expression to
|
||||
// old expression for transition state.
|
||||
type expressionConverter struct {
|
||||
exprMap map[ast.Node]expression.Expression
|
||||
paramMarkers paramMarkers
|
||||
err error
|
||||
}
|
||||
|
||||
func newExpressionConverter() *expressionConverter {
|
||||
return &expressionConverter{
|
||||
exprMap: map[ast.Node]expression.Expression{},
|
||||
}
|
||||
}
|
||||
|
||||
// Enter implements ast.Visitor interface.
|
||||
func (c *expressionConverter) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
|
||||
return in, false
|
||||
}
|
||||
|
||||
// Leave implements ast.Visitor interface.
|
||||
func (c *expressionConverter) Leave(in ast.Node) (out ast.Node, ok bool) {
|
||||
switch v := in.(type) {
|
||||
case *ast.ValueExpr:
|
||||
c.value(v)
|
||||
case *ast.BetweenExpr:
|
||||
c.between(v)
|
||||
case *ast.BinaryOperationExpr:
|
||||
c.binaryOperation(v)
|
||||
case *ast.WhenClause:
|
||||
c.whenClause(v)
|
||||
case *ast.CaseExpr:
|
||||
c.caseExpr(v)
|
||||
case *ast.SubqueryExpr:
|
||||
c.subquery(v)
|
||||
case *ast.CompareSubqueryExpr:
|
||||
c.compareSubquery(v)
|
||||
case *ast.ColumnNameExpr:
|
||||
c.columnNameExpr(v)
|
||||
case *ast.DefaultExpr:
|
||||
c.defaultExpr(v)
|
||||
case *ast.IdentifierExpr:
|
||||
c.identifier(v)
|
||||
case *ast.ExistsSubqueryExpr:
|
||||
c.existsSubquery(v)
|
||||
case *ast.PatternInExpr:
|
||||
c.patternIn(v)
|
||||
case *ast.IsNullExpr:
|
||||
c.isNull(v)
|
||||
case *ast.IsTruthExpr:
|
||||
c.isTruth(v)
|
||||
case *ast.PatternLikeExpr:
|
||||
c.patternLike(v)
|
||||
case *ast.ParamMarkerExpr:
|
||||
c.paramMarker(v)
|
||||
case *ast.ParenthesesExpr:
|
||||
c.parentheses(v)
|
||||
case *ast.PositionExpr:
|
||||
c.position(v)
|
||||
case *ast.PatternRegexpExpr:
|
||||
c.patternRegexp(v)
|
||||
case *ast.RowExpr:
|
||||
c.row(v)
|
||||
case *ast.UnaryOperationExpr:
|
||||
c.unaryOperation(v)
|
||||
case *ast.ValuesExpr:
|
||||
c.values(v)
|
||||
case *ast.VariableExpr:
|
||||
c.variable(v)
|
||||
case *ast.FuncCallExpr:
|
||||
c.funcCall(v)
|
||||
case *ast.FuncExtractExpr:
|
||||
c.funcExtract(v)
|
||||
case *ast.FuncConvertExpr:
|
||||
c.funcConvert(v)
|
||||
case *ast.FuncCastExpr:
|
||||
c.funcCast(v)
|
||||
case *ast.FuncSubstringExpr:
|
||||
c.funcSubstring(v)
|
||||
case *ast.FuncLocateExpr:
|
||||
c.funcLocate(v)
|
||||
case *ast.FuncTrimExpr:
|
||||
c.funcTrim(v)
|
||||
case *ast.FuncDateArithExpr:
|
||||
c.funcDateArith(v)
|
||||
case *ast.AggregateFuncExpr:
|
||||
c.aggregateFunc(v)
|
||||
}
|
||||
return in, c.err == nil
|
||||
}
|
||||
|
||||
func (c *expressionConverter) value(v *ast.ValueExpr) {
|
||||
c.exprMap[v] = expression.Value{Val: v.GetValue()}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) between(v *ast.BetweenExpr) {
|
||||
oldExpr := c.exprMap[v.Expr]
|
||||
oldLo := c.exprMap[v.Left]
|
||||
oldHi := c.exprMap[v.Right]
|
||||
oldBetween, err := expression.NewBetween(oldExpr, oldLo, oldHi, v.Not)
|
||||
if err != nil {
|
||||
c.err = err
|
||||
return
|
||||
}
|
||||
c.exprMap[v] = oldBetween
|
||||
}
|
||||
|
||||
func (c *expressionConverter) binaryOperation(v *ast.BinaryOperationExpr) {
|
||||
oldeLeft := c.exprMap[v.L]
|
||||
oldRight := c.exprMap[v.R]
|
||||
oldBinop := expression.NewBinaryOperation(v.Op, oldeLeft, oldRight)
|
||||
c.exprMap[v] = oldBinop
|
||||
}
|
||||
|
||||
func (c *expressionConverter) whenClause(v *ast.WhenClause) {
|
||||
oldExpr := c.exprMap[v.Expr]
|
||||
oldResult := c.exprMap[v.Result]
|
||||
oldWhenClause := &expression.WhenClause{Expr: oldExpr, Result: oldResult}
|
||||
c.exprMap[v] = oldWhenClause
|
||||
}
|
||||
|
||||
func (c *expressionConverter) caseExpr(v *ast.CaseExpr) {
|
||||
oldValue := c.exprMap[v.Value]
|
||||
oldWhenClauses := make([]*expression.WhenClause, len(v.WhenClauses))
|
||||
for i, val := range v.WhenClauses {
|
||||
oldWhenClauses[i] = c.exprMap[val].(*expression.WhenClause)
|
||||
}
|
||||
oldElse := c.exprMap[v.ElseClause]
|
||||
oldCaseExpr := &expression.FunctionCase{
|
||||
Value: oldValue,
|
||||
WhenClauses: oldWhenClauses,
|
||||
ElseClause: oldElse,
|
||||
}
|
||||
c.exprMap[v] = oldCaseExpr
|
||||
}
|
||||
|
||||
func (c *expressionConverter) subquery(v *ast.SubqueryExpr) {
|
||||
oldSubquery := &subquery.SubQuery{}
|
||||
switch x := v.Query.(type) {
|
||||
case *ast.SelectStmt:
|
||||
oldSelect, err := convertSelect(c, x)
|
||||
if err != nil {
|
||||
c.err = err
|
||||
return
|
||||
}
|
||||
oldSubquery.Stmt = oldSelect
|
||||
case *ast.UnionStmt:
|
||||
oldUnion, err := convertUnion(c, x)
|
||||
if err != nil {
|
||||
c.err = err
|
||||
return
|
||||
}
|
||||
oldSubquery.Stmt = oldUnion
|
||||
}
|
||||
c.exprMap[v] = oldSubquery
|
||||
}
|
||||
|
||||
func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) {
|
||||
expr := c.exprMap[v.L]
|
||||
subquery := c.exprMap[v.R]
|
||||
oldCmpSubquery := expression.NewCompareSubQuery(v.Op, expr, subquery.(expression.SubQuery), v.All)
|
||||
c.exprMap[v] = oldCmpSubquery
|
||||
}
|
||||
|
||||
func joinColumnName(columnName *ast.ColumnName) string {
|
||||
var originStrs []string
|
||||
if columnName.Schema.O != "" {
|
||||
originStrs = append(originStrs, columnName.Schema.O)
|
||||
}
|
||||
if columnName.Table.O != "" {
|
||||
originStrs = append(originStrs, columnName.Table.O)
|
||||
}
|
||||
originStrs = append(originStrs, columnName.Name.O)
|
||||
return strings.Join(originStrs, ".")
|
||||
}
|
||||
|
||||
func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) {
|
||||
ident := &expression.Ident{}
|
||||
ident.CIStr = model.NewCIStr(joinColumnName(v.Name))
|
||||
c.exprMap[v] = ident
|
||||
}
|
||||
|
||||
func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) {
|
||||
oldDefault := &expression.Default{}
|
||||
if v.Name != nil {
|
||||
oldDefault.Name = joinColumnName(v.Name)
|
||||
}
|
||||
c.exprMap[v] = oldDefault
|
||||
}
|
||||
|
||||
func (c *expressionConverter) identifier(v *ast.IdentifierExpr) {
|
||||
oldIdent := &expression.Ident{}
|
||||
oldIdent.CIStr = v.Name
|
||||
c.exprMap[v] = oldIdent
|
||||
}
|
||||
|
||||
func (c *expressionConverter) existsSubquery(v *ast.ExistsSubqueryExpr) {
|
||||
subquery := c.exprMap[v.Sel].(expression.SubQuery)
|
||||
c.exprMap[v] = &expression.ExistsSubQuery{Sel: subquery}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) patternIn(v *ast.PatternInExpr) {
|
||||
oldPatternIn := &expression.PatternIn{Not: v.Not}
|
||||
if v.Sel != nil {
|
||||
oldPatternIn.Sel = c.exprMap[v.Sel].(expression.SubQuery)
|
||||
}
|
||||
oldPatternIn.Expr = c.exprMap[v.Expr]
|
||||
if v.List != nil {
|
||||
oldPatternIn.List = make([]expression.Expression, len(v.List))
|
||||
for i, v := range v.List {
|
||||
oldPatternIn.List[i] = c.exprMap[v]
|
||||
}
|
||||
}
|
||||
c.exprMap[v] = oldPatternIn
|
||||
}
|
||||
|
||||
func (c *expressionConverter) isNull(v *ast.IsNullExpr) {
|
||||
oldIsNull := &expression.IsNull{Not: v.Not}
|
||||
oldIsNull.Expr = c.exprMap[v.Expr]
|
||||
c.exprMap[v] = oldIsNull
|
||||
}
|
||||
|
||||
func (c *expressionConverter) isTruth(v *ast.IsTruthExpr) {
|
||||
oldIsTruth := &expression.IsTruth{Not: v.Not, True: v.True}
|
||||
oldIsTruth.Expr = c.exprMap[v.Expr]
|
||||
c.exprMap[v] = oldIsTruth
|
||||
}
|
||||
|
||||
func (c *expressionConverter) patternLike(v *ast.PatternLikeExpr) {
|
||||
oldPatternLike := &expression.PatternLike{
|
||||
Not: v.Not,
|
||||
Escape: v.Escape,
|
||||
Expr: c.exprMap[v.Expr],
|
||||
Pattern: c.exprMap[v.Pattern],
|
||||
}
|
||||
c.exprMap[v] = oldPatternLike
|
||||
}
|
||||
|
||||
func (c *expressionConverter) paramMarker(v *ast.ParamMarkerExpr) {
|
||||
if c.exprMap[v] == nil {
|
||||
c.exprMap[v] = &expression.ParamMarker{}
|
||||
c.paramMarkers = append(c.paramMarkers, v)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) parentheses(v *ast.ParenthesesExpr) {
|
||||
oldExpr := c.exprMap[v.Expr]
|
||||
c.exprMap[v] = &expression.PExpr{Expr: oldExpr}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) position(v *ast.PositionExpr) {
|
||||
c.exprMap[v] = &expression.Position{N: v.N, Name: v.Name}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) patternRegexp(v *ast.PatternRegexpExpr) {
|
||||
oldPatternRegexp := &expression.PatternRegexp{
|
||||
Not: v.Not,
|
||||
Expr: c.exprMap[v.Expr],
|
||||
Pattern: c.exprMap[v.Pattern],
|
||||
}
|
||||
c.exprMap[v] = oldPatternRegexp
|
||||
}
|
||||
|
||||
func (c *expressionConverter) row(v *ast.RowExpr) {
|
||||
oldRow := &expression.Row{}
|
||||
oldRow.Values = make([]expression.Expression, len(v.Values))
|
||||
for i, val := range v.Values {
|
||||
oldRow.Values[i] = c.exprMap[val]
|
||||
}
|
||||
c.exprMap[v] = oldRow
|
||||
}
|
||||
|
||||
func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) {
|
||||
oldUnary := &expression.UnaryOperation{
|
||||
Op: v.Op,
|
||||
V: c.exprMap[v.V],
|
||||
}
|
||||
c.exprMap[v] = oldUnary
|
||||
}
|
||||
|
||||
func (c *expressionConverter) values(v *ast.ValuesExpr) {
|
||||
nameStr := joinColumnName(v.Column)
|
||||
c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) variable(v *ast.VariableExpr) {
|
||||
c.exprMap[v] = &expression.Variable{
|
||||
IsGlobal: v.IsGlobal,
|
||||
IsSystem: v.IsSystem,
|
||||
Name: v.Name,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcCall(v *ast.FuncCallExpr) {
|
||||
oldCall := &expression.Call{
|
||||
F: v.FnName,
|
||||
}
|
||||
oldCall.Args = make([]expression.Expression, len(v.Args))
|
||||
for i, val := range v.Args {
|
||||
oldCall.Args[i] = c.exprMap[val]
|
||||
}
|
||||
c.exprMap[v] = oldCall
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcExtract(v *ast.FuncExtractExpr) {
|
||||
oldExtract := &expression.Extract{Unit: v.Unit}
|
||||
oldExtract.Date = c.exprMap[v.Date]
|
||||
c.exprMap[v] = oldExtract
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcConvert(v *ast.FuncConvertExpr) {
|
||||
c.exprMap[v] = &expression.FunctionConvert{
|
||||
Expr: c.exprMap[v.Expr],
|
||||
Charset: v.Charset,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcCast(v *ast.FuncCastExpr) {
|
||||
oldCast := &expression.FunctionCast{
|
||||
Expr: c.exprMap[v.Expr],
|
||||
Tp: v.Tp,
|
||||
}
|
||||
switch v.FunctionType {
|
||||
case ast.CastBinaryOperator:
|
||||
oldCast.FunctionType = expression.BinaryOperator
|
||||
case ast.CastConvertFunction:
|
||||
oldCast.FunctionType = expression.ConvertFunction
|
||||
case ast.CastFunction:
|
||||
oldCast.FunctionType = expression.CastFunction
|
||||
}
|
||||
c.exprMap[v] = oldCast
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcSubstring(v *ast.FuncSubstringExpr) {
|
||||
oldSubstring := &expression.FunctionSubstring{
|
||||
Len: c.exprMap[v.Len],
|
||||
Pos: c.exprMap[v.Pos],
|
||||
StrExpr: c.exprMap[v.StrExpr],
|
||||
}
|
||||
c.exprMap[v] = oldSubstring
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcLocate(v *ast.FuncLocateExpr) {
|
||||
oldLocate := &expression.FunctionLocate{
|
||||
Pos: c.exprMap[v.Pos],
|
||||
Str: c.exprMap[v.Str],
|
||||
SubStr: c.exprMap[v.SubStr],
|
||||
}
|
||||
c.exprMap[v] = oldLocate
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcTrim(v *ast.FuncTrimExpr) {
|
||||
oldTrim := &expression.FunctionTrim{
|
||||
Str: c.exprMap[v.Str],
|
||||
RemStr: c.exprMap[v.RemStr],
|
||||
}
|
||||
switch v.Direction {
|
||||
case ast.TrimBoth:
|
||||
oldTrim.Direction = expression.TrimBoth
|
||||
case ast.TrimBothDefault:
|
||||
oldTrim.Direction = expression.TrimBothDefault
|
||||
case ast.TrimLeading:
|
||||
oldTrim.Direction = expression.TrimLeading
|
||||
case ast.TrimTrailing:
|
||||
oldTrim.Direction = expression.TrimTrailing
|
||||
}
|
||||
c.exprMap[v] = oldTrim
|
||||
}
|
||||
|
||||
func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) {
|
||||
oldDateArith := &expression.DateArith{
|
||||
Unit: v.Unit,
|
||||
Date: c.exprMap[v.Date],
|
||||
Interval: c.exprMap[v.Interval],
|
||||
}
|
||||
switch v.Op {
|
||||
case ast.DateAdd:
|
||||
oldDateArith.Op = expression.DateAdd
|
||||
case ast.DateSub:
|
||||
oldDateArith.Op = expression.DateSub
|
||||
}
|
||||
c.exprMap[v] = oldDateArith
|
||||
}
|
||||
|
||||
func (c *expressionConverter) aggregateFunc(v *ast.AggregateFuncExpr) {
|
||||
oldAggregate := &expression.Call{
|
||||
F: v.F,
|
||||
Distinct: v.Distinct,
|
||||
}
|
||||
for _, val := range v.Args {
|
||||
oldAggregate.Args = append(oldAggregate.Args, c.exprMap[val])
|
||||
}
|
||||
c.exprMap[v] = oldAggregate
|
||||
}
|
||||
1024
optimizer/convert_stmt.go
Normal file
1024
optimizer/convert_stmt.go
Normal file
File diff suppressed because it is too large
Load Diff
317
optimizer/evaluator.go
Normal file
317
optimizer/evaluator.go
Normal file
@ -0,0 +1,317 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer
|
||||
|
||||
import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/plan"
|
||||
)
|
||||
|
||||
// Evaluator is a ast visitor that evaluates an expression.
|
||||
type Evaluator struct {
|
||||
// columnMap is the map from ColumnName to the position of the rowStack.
|
||||
// It is used to find the value of the column.
|
||||
columnMap map[*ast.ColumnName]position
|
||||
// rowStack is the current row values while scanning.
|
||||
// It should be updated after scaned a new row.
|
||||
rowStack []*plan.Row
|
||||
|
||||
// the map from AggregateFuncExpr to aggregator index.
|
||||
aggregateMap map[*ast.AggregateFuncExpr]int
|
||||
|
||||
// aggregators for the current row
|
||||
// only outer query aggregate functions are handled.
|
||||
aggregators []Aggregator
|
||||
|
||||
// when aggregation phase is done, the input is
|
||||
aggregateDone bool
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
type position struct {
|
||||
stackOffset int
|
||||
fieldList bool
|
||||
columnOffset int
|
||||
}
|
||||
|
||||
// Enter implements ast.Visitor interface.
|
||||
func (e *Evaluator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
|
||||
return
|
||||
}
|
||||
|
||||
// Leave implements ast.Visitor interface.
|
||||
func (e *Evaluator) Leave(in ast.Node) (out ast.Node, ok bool) {
|
||||
switch v := in.(type) {
|
||||
case *ast.ValueExpr:
|
||||
ok = true
|
||||
case *ast.BetweenExpr:
|
||||
ok = e.between(v)
|
||||
case *ast.BinaryOperationExpr:
|
||||
ok = e.binaryOperation(v)
|
||||
case *ast.WhenClause:
|
||||
ok = e.whenClause(v)
|
||||
case *ast.CaseExpr:
|
||||
ok = e.caseExpr(v)
|
||||
case *ast.SubqueryExpr:
|
||||
ok = e.subquery(v)
|
||||
case *ast.CompareSubqueryExpr:
|
||||
ok = e.compareSubquery(v)
|
||||
case *ast.ColumnNameExpr:
|
||||
ok = e.columnName(v)
|
||||
case *ast.DefaultExpr:
|
||||
ok = e.defaultExpr(v)
|
||||
case *ast.IdentifierExpr:
|
||||
ok = e.identifier(v)
|
||||
case *ast.ExistsSubqueryExpr:
|
||||
ok = e.existsSubquery(v)
|
||||
case *ast.PatternInExpr:
|
||||
ok = e.patternIn(v)
|
||||
case *ast.IsNullExpr:
|
||||
ok = e.isNull(v)
|
||||
case *ast.IsTruthExpr:
|
||||
ok = e.isTruth(v)
|
||||
case *ast.PatternLikeExpr:
|
||||
ok = e.patternLike(v)
|
||||
case *ast.ParamMarkerExpr:
|
||||
ok = e.paramMarker(v)
|
||||
case *ast.ParenthesesExpr:
|
||||
ok = e.parentheses(v)
|
||||
case *ast.PositionExpr:
|
||||
ok = e.position(v)
|
||||
case *ast.PatternRegexpExpr:
|
||||
ok = e.patternRegexp(v)
|
||||
case *ast.RowExpr:
|
||||
ok = e.row(v)
|
||||
case *ast.UnaryOperationExpr:
|
||||
ok = e.unaryOperation(v)
|
||||
case *ast.ValuesExpr:
|
||||
ok = e.values(v)
|
||||
case *ast.VariableExpr:
|
||||
ok = e.variable(v)
|
||||
case *ast.FuncCallExpr:
|
||||
ok = e.funcCall(v)
|
||||
case *ast.FuncExtractExpr:
|
||||
ok = e.funcExtract(v)
|
||||
case *ast.FuncConvertExpr:
|
||||
ok = e.funcConvert(v)
|
||||
case *ast.FuncCastExpr:
|
||||
ok = e.funcCast(v)
|
||||
case *ast.FuncSubstringExpr:
|
||||
ok = e.funcSubstring(v)
|
||||
case *ast.FuncLocateExpr:
|
||||
ok = e.funcLocate(v)
|
||||
case *ast.FuncTrimExpr:
|
||||
ok = e.funcTrim(v)
|
||||
case *ast.AggregateFuncExpr:
|
||||
ok = e.aggregateFunc(v)
|
||||
}
|
||||
out = in
|
||||
return
|
||||
}
|
||||
|
||||
func checkAllOneColumn(exprs ...ast.ExprNode) bool {
|
||||
for _, expr := range exprs {
|
||||
switch v := expr.(type) {
|
||||
case *ast.RowExpr:
|
||||
return false
|
||||
case *ast.SubqueryExpr:
|
||||
if len(v.Query.GetResultFields()) != 1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) between(v *ast.BetweenExpr) bool {
|
||||
if !checkAllOneColumn(v.Expr, v.Left, v.Right) {
|
||||
e.err = errors.Errorf("Operand should contain 1 column(s)")
|
||||
return false
|
||||
}
|
||||
|
||||
var l, r ast.ExprNode
|
||||
op := opcode.AndAnd
|
||||
|
||||
if v.Not {
|
||||
// v < lv || v > rv
|
||||
op = opcode.OrOr
|
||||
l = &ast.BinaryOperationExpr{Op: opcode.LT, L: v.Expr, R: v.Left}
|
||||
r = &ast.BinaryOperationExpr{Op: opcode.GT, L: v.Expr, R: v.Right}
|
||||
} else {
|
||||
// v >= lv && v <= rv
|
||||
l = &ast.BinaryOperationExpr{Op: opcode.GE, L: v.Expr, R: v.Left}
|
||||
r = &ast.BinaryOperationExpr{Op: opcode.LE, L: v.Expr, R: v.Right}
|
||||
}
|
||||
|
||||
ret := &ast.BinaryOperationExpr{Op: op, L: l, R: r}
|
||||
ret.Accept(e)
|
||||
return e.err == nil
|
||||
}
|
||||
|
||||
func columnCount(e ast.ExprNode) (int, error) {
|
||||
switch x := e.(type) {
|
||||
case *ast.RowExpr:
|
||||
n := len(x.Values)
|
||||
if n <= 1 {
|
||||
return 0, errors.Errorf("Operand should contain >= 2 columns for Row")
|
||||
}
|
||||
return n, nil
|
||||
case *ast.SubqueryExpr:
|
||||
return len(x.Query.GetResultFields()), nil
|
||||
default:
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
func hasSameColumnCount(e ast.ExprNode, args ...ast.ExprNode) error {
|
||||
l, err := columnCount(e)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
var n int
|
||||
for _, arg := range args {
|
||||
n, err = columnCount(arg)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
if n != l {
|
||||
return errors.Errorf("Operand should contain %d column(s)", l)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Evaluator) whenClause(v *ast.WhenClause) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) caseExpr(v *ast.CaseExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) subquery(v *ast.SubqueryExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) compareSubquery(v *ast.CompareSubqueryExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) columnName(v *ast.ColumnNameExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) defaultExpr(v *ast.DefaultExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) identifier(v *ast.IdentifierExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) existsSubquery(v *ast.ExistsSubqueryExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) patternIn(v *ast.PatternInExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) isNull(v *ast.IsNullExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) isTruth(v *ast.IsTruthExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) patternLike(v *ast.PatternLikeExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) paramMarker(v *ast.ParamMarkerExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) parentheses(v *ast.ParenthesesExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) position(v *ast.PositionExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) patternRegexp(v *ast.PatternRegexpExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) row(v *ast.RowExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) unaryOperation(v *ast.UnaryOperationExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) values(v *ast.ValuesExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) variable(v *ast.VariableExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) funcCall(v *ast.FuncCallExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) funcExtract(v *ast.FuncExtractExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) funcConvert(v *ast.FuncConvertExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) funcCast(v *ast.FuncCastExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) funcSubstring(v *ast.FuncSubstringExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) funcLocate(v *ast.FuncLocateExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) funcTrim(v *ast.FuncTrimExpr) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
|
||||
idx := e.aggregateMap[v]
|
||||
aggr := e.aggregators[idx]
|
||||
if e.aggregateDone {
|
||||
v.SetValue(aggr.Output())
|
||||
return true
|
||||
}
|
||||
// TODO: currently only single argument aggregate functions are supported.
|
||||
e.err = aggr.Input(v.Args[0].GetValue())
|
||||
return e.err == nil
|
||||
}
|
||||
527
optimizer/evaluator_binop.go
Normal file
527
optimizer/evaluator_binop.go
Normal file
@ -0,0 +1,527 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
const (
|
||||
zeroI64 int64 = 0
|
||||
oneI64 int64 = 1
|
||||
)
|
||||
|
||||
func (e *Evaluator) binaryOperation(o *ast.BinaryOperationExpr) bool {
|
||||
// all operands must have same column.
|
||||
if e.err = hasSameColumnCount(o.L, o.R); e.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// row constructor only supports comparison operation.
|
||||
switch o.Op {
|
||||
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
|
||||
default:
|
||||
if !checkAllOneColumn(o.L) {
|
||||
e.err = errors.Errorf("Operand should contain 1 column(s)")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
leftVal, err := types.Convert(o.L.GetValue(), o.GetType())
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
rightVal, err := types.Convert(o.R.GetValue(), o.GetType())
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
if leftVal == nil || rightVal == nil {
|
||||
o.SetValue(nil)
|
||||
return true
|
||||
}
|
||||
|
||||
switch o.Op {
|
||||
case opcode.AndAnd, opcode.OrOr, opcode.LogicXor:
|
||||
return e.handleLogicOperation(o)
|
||||
case opcode.LT, opcode.LE, opcode.GE, opcode.GT, opcode.EQ, opcode.NE, opcode.NullEQ:
|
||||
return e.handleComparisonOp(o)
|
||||
case opcode.RightShift, opcode.LeftShift, opcode.And, opcode.Or, opcode.Xor:
|
||||
// TODO: MySQL doesn't support and not, we should remove it later.
|
||||
return e.handleBitOp(o)
|
||||
case opcode.Plus, opcode.Minus, opcode.Mod, opcode.Div, opcode.Mul, opcode.IntDiv:
|
||||
return e.handleArithmeticOp(o)
|
||||
default:
|
||||
panic("should never happen")
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Evaluator) handleLogicOperation(o *ast.BinaryOperationExpr) bool {
|
||||
leftVal, err := types.Convert(o.L.GetValue(), o.GetType())
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
rightVal, err := types.Convert(o.R.GetValue(), o.GetType())
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
if leftVal == nil || rightVal == nil {
|
||||
o.SetValue(nil)
|
||||
return true
|
||||
}
|
||||
var boolVal bool
|
||||
switch o.Op {
|
||||
case opcode.AndAnd:
|
||||
boolVal = leftVal != zeroI64 && rightVal != zeroI64
|
||||
case opcode.OrOr:
|
||||
boolVal = leftVal != zeroI64 || rightVal != zeroI64
|
||||
case opcode.LogicXor:
|
||||
boolVal = (leftVal == zeroI64 && rightVal != zeroI64) || (leftVal != zeroI64 && rightVal == zeroI64)
|
||||
default:
|
||||
panic("should never happen")
|
||||
}
|
||||
if boolVal {
|
||||
o.SetValue(oneI64)
|
||||
} else {
|
||||
o.SetValue(zeroI64)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) handleComparisonOp(o *ast.BinaryOperationExpr) bool {
|
||||
a, b := types.Coerce(o.L.GetValue(), o.R.GetValue())
|
||||
if types.IsNil(a) || types.IsNil(b) {
|
||||
// for <=>, if a and b are both nil, return true.
|
||||
// if a or b is nil, return false.
|
||||
if o.Op == opcode.NullEQ {
|
||||
if types.IsNil(a) || types.IsNil(b) {
|
||||
o.SetValue(oneI64)
|
||||
} else {
|
||||
o.SetValue(zeroI64)
|
||||
}
|
||||
} else {
|
||||
o.SetValue(nil)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
n, err := types.Compare(a, b)
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
r, err := getCompResult(o.Op, n)
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
if r {
|
||||
o.SetValue(oneI64)
|
||||
} else {
|
||||
o.SetValue(zeroI64)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func getCompResult(op opcode.Op, value int) (bool, error) {
|
||||
switch op {
|
||||
case opcode.LT:
|
||||
return value < 0, nil
|
||||
case opcode.LE:
|
||||
return value <= 0, nil
|
||||
case opcode.GE:
|
||||
return value >= 0, nil
|
||||
case opcode.GT:
|
||||
return value > 0, nil
|
||||
case opcode.EQ:
|
||||
return value == 0, nil
|
||||
case opcode.NE:
|
||||
return value != 0, nil
|
||||
case opcode.NullEQ:
|
||||
return value == 0, nil
|
||||
default:
|
||||
return false, errors.Errorf("invalid op %v in comparision operation", op)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Evaluator) handleBitOp(o *ast.BinaryOperationExpr) bool {
|
||||
a, b := types.Coerce(o.L.GetValue(), o.R.GetValue())
|
||||
|
||||
if types.IsNil(a) || types.IsNil(b) {
|
||||
o.SetValue(nil)
|
||||
return true
|
||||
}
|
||||
|
||||
x, err := types.ToInt64(a)
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
y, err := types.ToInt64(b)
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
// use a int64 for bit operator, return uint64
|
||||
switch o.Op {
|
||||
case opcode.And:
|
||||
o.SetValue(uint64(x & y))
|
||||
case opcode.Or:
|
||||
o.SetValue(uint64(x | y))
|
||||
case opcode.Xor:
|
||||
o.SetValue(uint64(x ^ y))
|
||||
case opcode.RightShift:
|
||||
o.SetValue(uint64(x) >> uint64(y))
|
||||
case opcode.LeftShift:
|
||||
o.SetValue(uint64(x) << uint64(y))
|
||||
default:
|
||||
e.err = errors.Errorf("invalid op %v in bit operation", o.Op)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *Evaluator) handleArithmeticOp(o *ast.BinaryOperationExpr) bool {
|
||||
a, err := coerceArithmetic(o.L.GetValue())
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
|
||||
b, err := coerceArithmetic(o.R.GetValue())
|
||||
if err != nil {
|
||||
e.err = err
|
||||
return false
|
||||
}
|
||||
a, b = types.Coerce(a, b)
|
||||
|
||||
if a == nil || b == nil {
|
||||
// TODO: for <=>, if a and b are both nil, return true
|
||||
o.SetValue(nil)
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO: support logic division DIV
|
||||
var result interface{}
|
||||
switch o.Op {
|
||||
case opcode.Plus:
|
||||
result, e.err = computePlus(a, b)
|
||||
case opcode.Minus:
|
||||
result, e.err = computeMinus(a, b)
|
||||
case opcode.Mul:
|
||||
result, e.err = computeMul(a, b)
|
||||
case opcode.Div:
|
||||
result, e.err = computeDiv(a, b)
|
||||
case opcode.Mod:
|
||||
result, e.err = computeMod(a, b)
|
||||
case opcode.IntDiv:
|
||||
result, e.err = computeIntDiv(a, b)
|
||||
default:
|
||||
e.err = errors.Errorf("invalid op %v in arithmetic operation", o.Op)
|
||||
return false
|
||||
}
|
||||
o.SetValue(result)
|
||||
return e.err == nil
|
||||
}
|
||||
|
||||
func computePlus(a, b interface{}) (interface{}, error) {
|
||||
switch x := a.(type) {
|
||||
case int64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
return types.AddInt64(x, y)
|
||||
case uint64:
|
||||
return types.AddInteger(y, x)
|
||||
}
|
||||
case uint64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
return types.AddInteger(x, y)
|
||||
case uint64:
|
||||
return types.AddUint64(x, y)
|
||||
}
|
||||
case float64:
|
||||
switch y := b.(type) {
|
||||
case float64:
|
||||
return x + y, nil
|
||||
}
|
||||
case mysql.Decimal:
|
||||
switch y := b.(type) {
|
||||
case mysql.Decimal:
|
||||
return x.Add(y), nil
|
||||
}
|
||||
}
|
||||
return types.InvOp2(a, b, opcode.Plus)
|
||||
}
|
||||
|
||||
func computeMinus(a, b interface{}) (interface{}, error) {
|
||||
switch x := a.(type) {
|
||||
case int64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
return types.SubInt64(x, y)
|
||||
case uint64:
|
||||
return types.SubIntWithUint(x, y)
|
||||
}
|
||||
case uint64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
return types.SubUintWithInt(x, y)
|
||||
case uint64:
|
||||
return types.SubUint64(x, y)
|
||||
}
|
||||
case float64:
|
||||
switch y := b.(type) {
|
||||
case float64:
|
||||
return x - y, nil
|
||||
}
|
||||
case mysql.Decimal:
|
||||
switch y := b.(type) {
|
||||
case mysql.Decimal:
|
||||
return x.Sub(y), nil
|
||||
}
|
||||
}
|
||||
|
||||
return types.InvOp2(a, b, opcode.Minus)
|
||||
}
|
||||
|
||||
func computeMul(a, b interface{}) (interface{}, error) {
|
||||
switch x := a.(type) {
|
||||
case int64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
return types.MulInt64(x, y)
|
||||
case uint64:
|
||||
return types.MulInteger(y, x)
|
||||
}
|
||||
case uint64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
return types.MulInteger(x, y)
|
||||
case uint64:
|
||||
return types.MulUint64(x, y)
|
||||
}
|
||||
case float64:
|
||||
switch y := b.(type) {
|
||||
case float64:
|
||||
return x * y, nil
|
||||
}
|
||||
case mysql.Decimal:
|
||||
switch y := b.(type) {
|
||||
case mysql.Decimal:
|
||||
return x.Mul(y), nil
|
||||
}
|
||||
}
|
||||
|
||||
return types.InvOp2(a, b, opcode.Mul)
|
||||
}
|
||||
|
||||
func computeDiv(a, b interface{}) (interface{}, error) {
|
||||
// MySQL support integer divison Div and division operator /
|
||||
// we use opcode.Div for division operator and will use another for integer division later.
|
||||
// for division operator, we will use float64 for calculation.
|
||||
switch x := a.(type) {
|
||||
case float64:
|
||||
y, err := types.ToFloat64(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return x / y, nil
|
||||
default:
|
||||
// the scale of the result is the scale of the first operand plus
|
||||
// the value of the div_precision_increment system variable (which is 4 by default)
|
||||
// we will use 4 here
|
||||
|
||||
xa, err := types.ToDecimal(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
xb, err := types.ToDecimal(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if f, _ := xb.Float64(); f == 0 {
|
||||
// division by zero return null
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return xa.Div(xb), nil
|
||||
}
|
||||
}
|
||||
|
||||
func computeMod(a, b interface{}) (interface{}, error) {
|
||||
switch x := a.(type) {
|
||||
case int64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return x % y, nil
|
||||
case uint64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
} else if x < 0 {
|
||||
// first is int64, return int64.
|
||||
return -int64(uint64(-x) % y), nil
|
||||
}
|
||||
return int64(uint64(x) % y), nil
|
||||
}
|
||||
case uint64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
} else if y < 0 {
|
||||
// first is uint64, return uint64.
|
||||
return uint64(x % uint64(-y)), nil
|
||||
}
|
||||
return x % uint64(y), nil
|
||||
case uint64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return x % y, nil
|
||||
}
|
||||
case float64:
|
||||
switch y := b.(type) {
|
||||
case float64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return math.Mod(x, y), nil
|
||||
}
|
||||
case mysql.Decimal:
|
||||
switch y := b.(type) {
|
||||
case mysql.Decimal:
|
||||
xf, _ := x.Float64()
|
||||
yf, _ := y.Float64()
|
||||
if yf == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return math.Mod(xf, yf), nil
|
||||
}
|
||||
}
|
||||
|
||||
return types.InvOp2(a, b, opcode.Mod)
|
||||
}
|
||||
|
||||
func computeIntDiv(a, b interface{}) (interface{}, error) {
|
||||
switch x := a.(type) {
|
||||
case int64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return types.DivInt64(x, y)
|
||||
case uint64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return types.DivIntWithUint(x, y)
|
||||
}
|
||||
case uint64:
|
||||
switch y := b.(type) {
|
||||
case int64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return types.DivUintWithInt(x, y)
|
||||
case uint64:
|
||||
if y == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return x / y, nil
|
||||
}
|
||||
}
|
||||
|
||||
// if any is none integer, use decimal to calculate
|
||||
x, err := types.ToDecimal(a)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
y, err := types.ToDecimal(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f, _ := y.Float64(); f == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return x.Div(y).IntPart(), nil
|
||||
}
|
||||
|
||||
func coerceArithmetic(a interface{}) (interface{}, error) {
|
||||
switch x := a.(type) {
|
||||
case string:
|
||||
// MySQL will convert string to float for arithmetic operation
|
||||
f, err := types.StrToFloat(x)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, err
|
||||
case mysql.Time:
|
||||
// if time has no precision, return int64
|
||||
v := x.ToNumber()
|
||||
if x.Fsp == 0 {
|
||||
return v.IntPart(), nil
|
||||
}
|
||||
return v, nil
|
||||
case mysql.Duration:
|
||||
// if duration has no precision, return int64
|
||||
v := x.ToNumber()
|
||||
if x.Fsp == 0 {
|
||||
return v.IntPart(), nil
|
||||
}
|
||||
return v, nil
|
||||
case []byte:
|
||||
// []byte is the same as string, converted to float for arithmetic operator.
|
||||
f, err := types.StrToFloat(string(x))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, err
|
||||
case mysql.Hex:
|
||||
return x.ToNumber(), nil
|
||||
case mysql.Bit:
|
||||
return x.ToNumber(), nil
|
||||
case mysql.Enum:
|
||||
return x.ToNumber(), nil
|
||||
case mysql.Set:
|
||||
return x.ToNumber(), nil
|
||||
default:
|
||||
return x, nil
|
||||
}
|
||||
}
|
||||
441
optimizer/infobinder.go
Normal file
441
optimizer/infobinder.go
Normal file
@ -0,0 +1,441 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer
|
||||
|
||||
import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/infoschema"
|
||||
"github.com/pingcap/tidb/model"
|
||||
)
|
||||
|
||||
// InfoBinder binds schema information for table name and column name and set result fields
|
||||
// for ResetSetNode.
|
||||
// We need to know which table a table name refers to, which column a column name refers to.
|
||||
//
|
||||
// In general, a reference can only refer to information that are available for it.
|
||||
// So children elements are visited in the order that previous elements make information
|
||||
// available for following elements.
|
||||
//
|
||||
// During visiting, information are collected and stored in binderContext.
|
||||
// When we enter a sub query, a new binderContext is pushed to the contextStack, so sub query
|
||||
// information can overwrite outer query information. When we look up for a column reference,
|
||||
// we look up from top to bottom in the contextStack.
|
||||
type InfoBinder struct {
|
||||
Info infoschema.InfoSchema
|
||||
DefaultSchema model.CIStr
|
||||
Err error
|
||||
|
||||
contextStack []*binderContext
|
||||
}
|
||||
|
||||
// binderContext stores information that table name and column name
|
||||
// can be bind to.
|
||||
type binderContext struct {
|
||||
/* For Select Statement. */
|
||||
// table map to lookup and check table name conflict.
|
||||
tableMap map[string]int
|
||||
// tableSources collected in from clause.
|
||||
tables []*ast.TableSource
|
||||
// result fields collected in select field list.
|
||||
fieldList []*ast.ResultField
|
||||
// result fields collected in group by clause.
|
||||
groupBy []*ast.ResultField
|
||||
|
||||
// The join node stack is used by on condition to find out
|
||||
// available tables to reference. On condition can only
|
||||
// refer to tables involved in current join.
|
||||
joinNodeStack []*ast.Join
|
||||
|
||||
// When visiting TableRefs, tables in this context are not available
|
||||
// because it is being collected.
|
||||
inTableRefs bool
|
||||
// When visiting on conditon only tables in current join node are available.
|
||||
inOnCondition bool
|
||||
// When visiting field list, fieldList in this context are not available.
|
||||
inFieldList bool
|
||||
// When visiting group by, groupBy fields are not available.
|
||||
inGroupBy bool
|
||||
// When visiting having, only fieldList and groupBy fields are available.
|
||||
inHaving bool
|
||||
}
|
||||
|
||||
// currentContext gets the current binder context.
|
||||
func (sb *InfoBinder) currentContext() *binderContext {
|
||||
stackLen := len(sb.contextStack)
|
||||
if stackLen == 0 {
|
||||
return nil
|
||||
}
|
||||
return sb.contextStack[stackLen-1]
|
||||
}
|
||||
|
||||
// pushContext is called when we enter a statement.
|
||||
func (sb *InfoBinder) pushContext() {
|
||||
sb.contextStack = append(sb.contextStack, &binderContext{
|
||||
tableMap: map[string]int{},
|
||||
})
|
||||
}
|
||||
|
||||
// popContext is called when we leave a statement.
|
||||
func (sb *InfoBinder) popContext() {
|
||||
sb.contextStack = sb.contextStack[:len(sb.contextStack)-1]
|
||||
}
|
||||
|
||||
// pushJoin is called when we enter a join node.
|
||||
func (sb *InfoBinder) pushJoin(j *ast.Join) {
|
||||
ctx := sb.currentContext()
|
||||
ctx.joinNodeStack = append(ctx.joinNodeStack, j)
|
||||
}
|
||||
|
||||
// popJoin is called when we leave a join node.
|
||||
func (sb *InfoBinder) popJoin() {
|
||||
ctx := sb.currentContext()
|
||||
ctx.joinNodeStack = ctx.joinNodeStack[:len(ctx.joinNodeStack)-1]
|
||||
}
|
||||
|
||||
// Enter implements ast.Visitor interface.
|
||||
func (sb *InfoBinder) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) {
|
||||
switch v := inNode.(type) {
|
||||
case *ast.SelectStmt:
|
||||
sb.pushContext()
|
||||
case *ast.TableRefsClause:
|
||||
sb.currentContext().inTableRefs = true
|
||||
case *ast.Join:
|
||||
sb.pushJoin(v)
|
||||
case *ast.OnCondition:
|
||||
sb.currentContext().inOnCondition = true
|
||||
case *ast.FieldList:
|
||||
sb.currentContext().inFieldList = true
|
||||
case *ast.GroupByClause:
|
||||
sb.currentContext().inGroupBy = true
|
||||
case *ast.HavingClause:
|
||||
sb.currentContext().inHaving = true
|
||||
case *ast.InsertStmt:
|
||||
sb.pushContext()
|
||||
case *ast.DeleteStmt:
|
||||
sb.pushContext()
|
||||
case *ast.UpdateStmt:
|
||||
sb.pushContext()
|
||||
}
|
||||
return inNode, false
|
||||
}
|
||||
|
||||
// Leave implements ast.Visitor interface.
|
||||
func (sb *InfoBinder) Leave(inNode ast.Node) (node ast.Node, ok bool) {
|
||||
switch v := inNode.(type) {
|
||||
case *ast.TableName:
|
||||
sb.handleTableName(v)
|
||||
case *ast.ColumnName:
|
||||
sb.handleColumnName(v)
|
||||
case *ast.TableSource:
|
||||
sb.handleTableSource(v)
|
||||
case *ast.OnCondition:
|
||||
sb.currentContext().inOnCondition = false
|
||||
case *ast.Join:
|
||||
sb.handleJoin(v)
|
||||
sb.popJoin()
|
||||
case *ast.TableRefsClause:
|
||||
sb.currentContext().inTableRefs = false
|
||||
case *ast.FieldList:
|
||||
sb.handleFieldList(v)
|
||||
sb.currentContext().inFieldList = false
|
||||
case *ast.GroupByClause:
|
||||
sb.currentContext().inGroupBy = false
|
||||
case *ast.HavingClause:
|
||||
sb.currentContext().inHaving = false
|
||||
case *ast.SelectStmt:
|
||||
v.SetResultFields(sb.currentContext().fieldList)
|
||||
sb.popContext()
|
||||
case *ast.InsertStmt:
|
||||
sb.popContext()
|
||||
case *ast.DeleteStmt:
|
||||
sb.popContext()
|
||||
case *ast.UpdateStmt:
|
||||
sb.popContext()
|
||||
}
|
||||
return inNode, sb.Err == nil
|
||||
}
|
||||
|
||||
// handleTableName looks up and bind the schema information for table name
|
||||
// and set result fields for table name.
|
||||
func (sb *InfoBinder) handleTableName(tn *ast.TableName) {
|
||||
if tn.Schema.L == "" {
|
||||
tn.Schema = sb.DefaultSchema
|
||||
}
|
||||
table, err := sb.Info.TableByName(tn.Schema, tn.Name)
|
||||
if err != nil {
|
||||
sb.Err = err
|
||||
return
|
||||
}
|
||||
tn.TableInfo = table.Meta()
|
||||
dbInfo, _ := sb.Info.SchemaByName(tn.Schema)
|
||||
tn.DBInfo = dbInfo
|
||||
|
||||
rfs := make([]*ast.ResultField, len(tn.TableInfo.Columns))
|
||||
for i, v := range tn.TableInfo.Columns {
|
||||
rfs[i] = &ast.ResultField{
|
||||
Column: v,
|
||||
Table: tn.TableInfo,
|
||||
DBName: tn.Schema,
|
||||
}
|
||||
}
|
||||
tn.SetResultFields(rfs)
|
||||
return
|
||||
}
|
||||
|
||||
// handleTableSources checks name duplication
|
||||
// and puts the table source in current binderContext.
|
||||
func (sb *InfoBinder) handleTableSource(ts *ast.TableSource) {
|
||||
for _, v := range ts.GetResultFields() {
|
||||
v.TableAsName = ts.AsName
|
||||
}
|
||||
var name string
|
||||
if ts.AsName.L != "" {
|
||||
name = ts.AsName.L
|
||||
} else {
|
||||
tableName := ts.Source.(*ast.TableName)
|
||||
name = sb.tableUniqueName(tableName.Schema, tableName.Name)
|
||||
}
|
||||
ctx := sb.currentContext()
|
||||
if _, ok := ctx.tableMap[name]; ok {
|
||||
sb.Err = errors.Errorf("duplicated table/alias name %s", name)
|
||||
return
|
||||
}
|
||||
ctx.tableMap[name] = len(ctx.tables)
|
||||
ctx.tables = append(ctx.tables, ts)
|
||||
return
|
||||
}
|
||||
|
||||
// handleJoin sets result fields for join.
|
||||
func (sb *InfoBinder) handleJoin(j *ast.Join) {
|
||||
if j.Right == nil {
|
||||
j.SetResultFields(j.Left.GetResultFields())
|
||||
return
|
||||
}
|
||||
leftLen := len(j.Left.GetResultFields())
|
||||
rightLen := len(j.Right.GetResultFields())
|
||||
rfs := make([]*ast.ResultField, leftLen+rightLen)
|
||||
copy(rfs, j.Left.GetResultFields())
|
||||
copy(rfs[leftLen:], j.Right.GetResultFields())
|
||||
j.SetResultFields(rfs)
|
||||
}
|
||||
|
||||
// handleColumnName looks up and binds schema information to
|
||||
// the column name.
|
||||
func (sb *InfoBinder) handleColumnName(cn *ast.ColumnName) {
|
||||
ctx := sb.currentContext()
|
||||
if ctx.inOnCondition {
|
||||
// In on condition, only tables within current join is available.
|
||||
sb.bindColumnNameInOnCondition(cn)
|
||||
return
|
||||
}
|
||||
|
||||
// Try to bind the column name form top to bottom in the context stack.
|
||||
for i := len(sb.contextStack) - 1; i >= 0; i-- {
|
||||
if sb.bindColumnNameInContext(sb.contextStack[i], cn) {
|
||||
// Column is already bound or encountered an error.
|
||||
return
|
||||
}
|
||||
}
|
||||
sb.Err = errors.Errorf("Unknown column %s", cn.Name.L)
|
||||
}
|
||||
|
||||
// bindColumnNameInContext looks up and binds schema information for a column with the ctx.
|
||||
func (sb *InfoBinder) bindColumnNameInContext(ctx *binderContext, cn *ast.ColumnName) (done bool) {
|
||||
if cn.Table.L == "" {
|
||||
// If qualified table name is not specified in column name, the column name may be ambiguous,
|
||||
// We need to iterate over all tables and
|
||||
}
|
||||
|
||||
if ctx.inTableRefs {
|
||||
// In TableRefsClause, column reference only in join on condition which is handled before.
|
||||
return false
|
||||
}
|
||||
if ctx.inFieldList {
|
||||
// only bind column using tables.
|
||||
return sb.bindColumnInTableSources(cn, ctx.tables)
|
||||
}
|
||||
if ctx.inGroupBy {
|
||||
// field list first, then tables.
|
||||
if sb.bindColumnInResultFields(cn, ctx.fieldList) {
|
||||
return true
|
||||
}
|
||||
return sb.bindColumnInTableSources(cn, ctx.tables)
|
||||
}
|
||||
// column name in other places can be looked up in the same order.
|
||||
if sb.bindColumnInResultFields(cn, ctx.groupBy) {
|
||||
return true
|
||||
}
|
||||
if sb.bindColumnInResultFields(cn, ctx.fieldList) {
|
||||
return true
|
||||
}
|
||||
|
||||
// tables is not available for having clause.
|
||||
if !ctx.inHaving {
|
||||
return sb.bindColumnInTableSources(cn, ctx.tables)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// bindColumnNameInOnCondition looks up for column name in current join, and
|
||||
// binds the schema information.
|
||||
func (sb *InfoBinder) bindColumnNameInOnCondition(cn *ast.ColumnName) {
|
||||
ctx := sb.currentContext()
|
||||
join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1]
|
||||
tableSources := appendTableSources(nil, join)
|
||||
if !sb.bindColumnInTableSources(cn, tableSources) {
|
||||
sb.Err = errors.Errorf("unkown column name %s", cn.Name.O)
|
||||
}
|
||||
}
|
||||
|
||||
func (sb *InfoBinder) bindColumnInTableSources(cn *ast.ColumnName, tableSources []*ast.TableSource) (done bool) {
|
||||
var matchedResultField *ast.ResultField
|
||||
if cn.Table.L != "" {
|
||||
var matchedTable ast.ResultSetNode
|
||||
for _, ts := range tableSources {
|
||||
if cn.Table.L == ts.AsName.L {
|
||||
// different table name.
|
||||
matchedTable = ts
|
||||
break
|
||||
}
|
||||
if tn, ok := ts.Source.(*ast.TableName); ok {
|
||||
if cn.Table.L == tn.Name.L {
|
||||
matchedTable = ts
|
||||
}
|
||||
}
|
||||
}
|
||||
if matchedTable != nil {
|
||||
resultFields := matchedTable.GetResultFields()
|
||||
for _, rf := range resultFields {
|
||||
if rf.ColumnAsName.L == cn.Name.L || rf.Column.Name.L == cn.Name.L {
|
||||
// bind column.
|
||||
matchedResultField = rf
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for _, ts := range tableSources {
|
||||
rfs := ts.GetResultFields()
|
||||
for _, rf := range rfs {
|
||||
matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L
|
||||
matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L
|
||||
if matchAsName || matchColumnName {
|
||||
if matchedResultField != nil {
|
||||
sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O)
|
||||
return true
|
||||
}
|
||||
matchedResultField = rf
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if matchedResultField != nil {
|
||||
// bind column.
|
||||
cn.ColumnInfo = matchedResultField.Column
|
||||
cn.TableInfo = matchedResultField.Table
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (sb *InfoBinder) bindColumnInResultFields(cn *ast.ColumnName, rfs []*ast.ResultField) bool {
|
||||
var matchedResultField *ast.ResultField
|
||||
for _, rf := range rfs {
|
||||
matchAsName := rf.ColumnAsName.L != "" && rf.ColumnAsName.L == cn.Name.L
|
||||
matchColumnName := rf.ColumnAsName.L == "" && rf.Column.Name.L == cn.Name.L
|
||||
if matchAsName || matchColumnName {
|
||||
if matchedResultField != nil {
|
||||
sb.Err = errors.Errorf("column %s is ambiguous.", cn.Name.O)
|
||||
return false
|
||||
}
|
||||
matchedResultField = rf
|
||||
}
|
||||
}
|
||||
if matchedResultField != nil {
|
||||
// bind column.
|
||||
cn.ColumnInfo = matchedResultField.Column
|
||||
cn.TableInfo = matchedResultField.Table
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// handleFieldList expands wild card field and set fieldList in current context.
|
||||
func (sb *InfoBinder) handleFieldList(fieldList *ast.FieldList) {
|
||||
var resultFields []*ast.ResultField
|
||||
for _, v := range fieldList.Fields {
|
||||
resultFields = append(resultFields, sb.createResultFields(v)...)
|
||||
}
|
||||
sb.currentContext().fieldList = resultFields
|
||||
}
|
||||
|
||||
// createResultFields creates result field list for a single select field.
|
||||
func (sb *InfoBinder) createResultFields(field *ast.SelectField) (rfs []*ast.ResultField) {
|
||||
ctx := sb.currentContext()
|
||||
if field.WildCard != nil {
|
||||
if len(ctx.tables) == 0 {
|
||||
sb.Err = errors.Errorf("No table used.")
|
||||
return
|
||||
}
|
||||
if field.WildCard.Table.L == "" {
|
||||
for _, v := range ctx.tables {
|
||||
rfs = append(rfs, v.GetResultFields()...)
|
||||
}
|
||||
} else {
|
||||
name := sb.tableUniqueName(field.WildCard.Schema, field.WildCard.Table)
|
||||
tableIdx, ok := ctx.tableMap[name]
|
||||
if !ok {
|
||||
sb.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O)
|
||||
}
|
||||
rfs = ctx.tables[tableIdx].GetResultFields()
|
||||
}
|
||||
return
|
||||
}
|
||||
// The column is visited before so it must has been bound already.
|
||||
rf := &ast.ResultField{ColumnAsName: field.AsName}
|
||||
switch v := field.Expr.(type) {
|
||||
case *ast.ColumnNameExpr:
|
||||
rf.Column = v.Name.ColumnInfo
|
||||
rf.Table = v.Name.TableInfo
|
||||
rf.DBName = v.Name.Schema
|
||||
default:
|
||||
if field.AsName.L == "" {
|
||||
rf.ColumnAsName.L = field.Expr.Text()
|
||||
rf.ColumnAsName.O = rf.ColumnAsName.L
|
||||
}
|
||||
}
|
||||
rfs = append(rfs, rf)
|
||||
return
|
||||
}
|
||||
|
||||
func appendTableSources(in []*ast.TableSource, resultSetNode ast.ResultSetNode) (out []*ast.TableSource) {
|
||||
switch v := resultSetNode.(type) {
|
||||
case *ast.TableSource:
|
||||
out = append(in, v)
|
||||
case *ast.Join:
|
||||
out = appendTableSources(in, v.Left)
|
||||
if v.Right != nil {
|
||||
out = appendTableSources(out, v.Right)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (sb *InfoBinder) tableUniqueName(schema, table model.CIStr) string {
|
||||
if schema.L != "" && schema.L != sb.DefaultSchema.L {
|
||||
return schema.L + "." + table.L
|
||||
}
|
||||
return table.L
|
||||
}
|
||||
82
optimizer/infobinder_test.go
Normal file
82
optimizer/infobinder_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/optimizer"
|
||||
"github.com/pingcap/tidb/parser"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/util/testkit"
|
||||
)
|
||||
|
||||
func TestT(t *testing.T) {
|
||||
TestingT(t)
|
||||
}
|
||||
|
||||
var _ = Suite(&testInfoBinderSuite{})
|
||||
|
||||
type testInfoBinderSuite struct {
|
||||
}
|
||||
|
||||
type binderVerifier struct {
|
||||
c *C
|
||||
}
|
||||
|
||||
func (bv *binderVerifier) Enter(node ast.Node) (ast.Node, bool) {
|
||||
return node, false
|
||||
}
|
||||
|
||||
func (bv *binderVerifier) Leave(in ast.Node) (out ast.Node, ok bool) {
|
||||
switch v := in.(type) {
|
||||
case *ast.ColumnName:
|
||||
bv.c.Assert(v.ColumnInfo, NotNil)
|
||||
case *ast.TableName:
|
||||
bv.c.Assert(v.TableInfo, NotNil)
|
||||
}
|
||||
return in, true
|
||||
}
|
||||
|
||||
func (ts *testInfoBinderSuite) TestInfoBinder(c *C) {
|
||||
store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory)
|
||||
c.Assert(err, IsNil)
|
||||
defer store.Close()
|
||||
testKit := testkit.NewTestKit(c, store)
|
||||
testKit.MustExec("use test")
|
||||
testKit.MustExec("create table t (c1 int, c2 int)")
|
||||
domain := sessionctx.GetDomain(testKit.Se.(context.Context))
|
||||
|
||||
src := "SELECT c1 from t"
|
||||
l := parser.NewLexer(src)
|
||||
c.Assert(parser.YYParse(l), Equals, 0)
|
||||
stmts := l.Stmts()
|
||||
c.Assert(len(stmts), Equals, 1)
|
||||
v := &optimizer.InfoBinder{
|
||||
Info: domain.InfoSchema(),
|
||||
DefaultSchema: model.NewCIStr("test"),
|
||||
}
|
||||
selectStmt := stmts[0].(*ast.SelectStmt)
|
||||
selectStmt.Accept(v)
|
||||
|
||||
verifier := &binderVerifier{
|
||||
c: c,
|
||||
}
|
||||
selectStmt.Accept(verifier)
|
||||
}
|
||||
@ -13,20 +13,18 @@
|
||||
|
||||
package optimizer
|
||||
|
||||
import (
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/stmt"
|
||||
)
|
||||
import "github.com/pingcap/tidb/ast"
|
||||
|
||||
// Compile compiles a ast.Node into a executable statement.
|
||||
func Compile(node ast.Node) (stmt.Statement, error) {
|
||||
switch v := node.(type) {
|
||||
case *ast.SetStmt:
|
||||
return compileSet(v)
|
||||
}
|
||||
return nil, nil
|
||||
// typeComputer is an ast Visitor that
|
||||
// computes result type for ast.ExprNode.
|
||||
type typeComputer struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func compileSet(aset *ast.SetStmt) (stmt.Statement, error) {
|
||||
return nil, nil
|
||||
func (v *typeComputer) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
|
||||
return in, false
|
||||
}
|
||||
|
||||
func (v *typeComputer) Leave(in ast.Node) (out ast.Node, ok bool) {
|
||||
return in, true
|
||||
}
|
||||
30
optimizer/validator.go
Normal file
30
optimizer/validator.go
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package optimizer
|
||||
|
||||
import "github.com/pingcap/tidb/ast"
|
||||
|
||||
// validator is an ast.Visitor that validates
|
||||
// ast parsed from parser.
|
||||
type validator struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
|
||||
return in, false
|
||||
}
|
||||
|
||||
func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
|
||||
return in, true
|
||||
}
|
||||
2147
parser/parser.y
2147
parser/parser.y
File diff suppressed because it is too large
Load Diff
@ -18,9 +18,7 @@ import (
|
||||
"testing"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/expression/subquery"
|
||||
"github.com/pingcap/tidb/stmt/stmts"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
)
|
||||
|
||||
func TestT(t *testing.T) {
|
||||
@ -32,23 +30,6 @@ var _ = Suite(&testParserSuite{})
|
||||
type testParserSuite struct {
|
||||
}
|
||||
|
||||
func (s *testParserSuite) TestOriginText(c *C) {
|
||||
src := `SELECT stuff.id
|
||||
FROM stuff
|
||||
WHERE stuff.value >= ALL (SELECT stuff.value
|
||||
FROM stuff)`
|
||||
|
||||
l := NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
node := l.Stmts()[0].(*stmts.SelectStmt)
|
||||
sq := node.Where.Expr.(*expression.CompareSubQuery).R
|
||||
c.Assert(sq, NotNil)
|
||||
subsel := sq.(*subquery.SubQuery)
|
||||
c.Assert(subsel.Stmt.OriginText(), Equals,
|
||||
`SELECT stuff.value
|
||||
FROM stuff`)
|
||||
}
|
||||
|
||||
func (s *testParserSuite) TestSimple(c *C) {
|
||||
// Testcase for unreserved keywords
|
||||
unreservedKws := []string{
|
||||
@ -70,15 +51,12 @@ func (s *testParserSuite) TestSimple(c *C) {
|
||||
// Testcase for prepared statement
|
||||
src := "SELECT id+?, id+? from t;"
|
||||
l := NewLexer(src)
|
||||
l.SetPrepare()
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
c.Assert(len(l.ParamList), Equals, 2)
|
||||
c.Assert(len(l.Stmts()), Equals, 1)
|
||||
|
||||
// Testcase for -- Comment and unary -- operator
|
||||
src = "CREATE TABLE foo (a SMALLINT UNSIGNED, b INT UNSIGNED); -- foo\nSelect --1 from foo;"
|
||||
l = NewLexer(src)
|
||||
l.SetPrepare()
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
c.Assert(len(l.Stmts()), Equals, 2)
|
||||
|
||||
@ -87,11 +65,12 @@ func (s *testParserSuite) TestSimple(c *C) {
|
||||
l = NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
st := l.Stmts()[0]
|
||||
ss, ok := st.(*stmts.SelectStmt)
|
||||
ss, ok := st.(*ast.SelectStmt)
|
||||
c.Assert(ok, IsTrue)
|
||||
cv, ok := ss.Fields[0].Expr.(*expression.FunctionCast)
|
||||
c.Assert(len(ss.Fields.Fields), Equals, 1)
|
||||
cv, ok := ss.Fields.Fields[0].Expr.(*ast.FuncCastExpr)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(cv.FunctionType, Equals, expression.ConvertFunction)
|
||||
c.Assert(cv.FunctionType, Equals, ast.CastConvertFunction)
|
||||
|
||||
// For query start with comment
|
||||
srcs := []string{
|
||||
@ -105,7 +84,7 @@ func (s *testParserSuite) TestSimple(c *C) {
|
||||
l = NewLexer(src)
|
||||
c.Assert(yyParse(l), Equals, 0)
|
||||
st = l.Stmts()[0]
|
||||
ss, ok = st.(*stmts.SelectStmt)
|
||||
ss, ok = st.(*ast.SelectStmt)
|
||||
c.Assert(ok, IsTrue)
|
||||
}
|
||||
}
|
||||
@ -172,9 +151,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
|
||||
{"REPLACE INTO foo () VALUES ()", true},
|
||||
{"REPLACE INTO foo VALUE ()", true},
|
||||
// 40
|
||||
{`SELECT stuff.id
|
||||
FROM stuff
|
||||
WHERE stuff.value >= ALL (SELECT stuff.value
|
||||
{`SELECT stuff.id
|
||||
FROM stuff
|
||||
WHERE stuff.value >= ALL (SELECT stuff.value
|
||||
FROM stuff)`, true},
|
||||
{"BEGIN", true},
|
||||
{"START TRANSACTION", true},
|
||||
|
||||
@ -22,25 +22,25 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"strconv"
|
||||
"unicode"
|
||||
"strings"
|
||||
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/util/stringutil"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"unicode"
|
||||
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/stringutil"
|
||||
)
|
||||
|
||||
type lexer struct {
|
||||
c int
|
||||
col int
|
||||
errs []error
|
||||
expr expression.Expression
|
||||
expr ast.ExprNode
|
||||
i int
|
||||
inj int
|
||||
lcol int
|
||||
line int
|
||||
list []interface{}
|
||||
list []ast.StmtNode
|
||||
ncol int
|
||||
nline int
|
||||
sc int
|
||||
@ -48,8 +48,7 @@ type lexer struct {
|
||||
val []byte
|
||||
ungetBuf []byte
|
||||
root bool
|
||||
prepare bool
|
||||
ParamList []*expression.ParamMarker
|
||||
prepare bool
|
||||
stmtStartPos int
|
||||
stringLit []byte
|
||||
|
||||
@ -78,11 +77,11 @@ func (l *lexer) Errors() []error {
|
||||
return l.errs
|
||||
}
|
||||
|
||||
func (l *lexer) Stmts() []interface{}{
|
||||
func (l *lexer) Stmts() []ast.StmtNode {
|
||||
return l.list
|
||||
}
|
||||
|
||||
func (l *lexer) Expr() expression.Expression {
|
||||
func (l *lexer) Expr() ast.ExprNode {
|
||||
return l.expr
|
||||
}
|
||||
|
||||
@ -90,16 +89,16 @@ func (l *lexer) Inj() int {
|
||||
return l.inj
|
||||
}
|
||||
|
||||
func (l *lexer) SetInj(inj int) {
|
||||
l.inj = inj
|
||||
}
|
||||
|
||||
func (l *lexer) SetPrepare() {
|
||||
l.prepare = true
|
||||
l.prepare = true
|
||||
}
|
||||
|
||||
func (l *lexer) IsPrepare() bool {
|
||||
return l.prepare
|
||||
}
|
||||
|
||||
func (l *lexer) SetInj(inj int) {
|
||||
l.inj = inj
|
||||
return l.prepare
|
||||
}
|
||||
|
||||
func (l *lexer) Root() bool {
|
||||
@ -116,7 +115,17 @@ func (l *lexer) SetCharsetInfo(charset, collation string) {
|
||||
}
|
||||
|
||||
func (l *lexer) GetCharsetInfo() (string, string) {
|
||||
return l.charset, l.collation
|
||||
return l.charset, l.collation
|
||||
}
|
||||
|
||||
// The select statement is not at the end of the whole statement, if the last
|
||||
// field text was set from its offset to the end of the src string, update
|
||||
// the last field text.
|
||||
func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) {
|
||||
lastField := st.Fields.Fields[len(st.Fields.Fields)-1]
|
||||
if lastField.Offset + len(lastField.Text()) >= len(l.src)-1 {
|
||||
lastField.SetText(l.src[lastField.Offset:lastEnd])
|
||||
}
|
||||
}
|
||||
|
||||
func (l *lexer) unget(b byte) {
|
||||
@ -822,9 +831,9 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
|
||||
{repeat} lval.item = string(l.val)
|
||||
return repeat
|
||||
{regexp} return regexp
|
||||
{references} return references
|
||||
{replace} lval.item = string(l.val)
|
||||
return replace
|
||||
{references} return references
|
||||
{rlike} return rlike
|
||||
|
||||
{sys_var} lval.item = string(l.val)
|
||||
|
||||
@ -1,92 +0,0 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package parser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"unicode"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
)
|
||||
|
||||
func tok2name(i int) string {
|
||||
if i == unicode.ReplacementChar {
|
||||
return "<?>"
|
||||
}
|
||||
|
||||
if i < 128 {
|
||||
return fmt.Sprintf("tok-'%c'", i)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("tok-%d", i)
|
||||
}
|
||||
|
||||
func (s *testParserSuite) TestScaner0(c *C) {
|
||||
table := []struct {
|
||||
src string
|
||||
tok, line, col, nline, ncol int
|
||||
val string
|
||||
}{
|
||||
{"a", identifier, 1, 1, 1, 2, "a"},
|
||||
{" a", identifier, 1, 2, 1, 3, "a"},
|
||||
{"a ", identifier, 1, 1, 1, 2, "a"},
|
||||
{" a ", identifier, 1, 2, 1, 3, "a"},
|
||||
{"\na", identifier, 2, 1, 2, 2, "a"},
|
||||
|
||||
{"a\n", identifier, 1, 1, 1, 2, "a"},
|
||||
{"\na\n", identifier, 2, 1, 2, 2, "a"},
|
||||
{"\n a", identifier, 2, 2, 2, 3, "a"},
|
||||
{"a \n", identifier, 1, 1, 1, 2, "a"},
|
||||
{"\n a \n", identifier, 2, 2, 2, 3, "a"},
|
||||
|
||||
{"ab", identifier, 1, 1, 1, 3, "ab"},
|
||||
{" ab", identifier, 1, 2, 1, 4, "ab"},
|
||||
{"ab ", identifier, 1, 1, 1, 3, "ab"},
|
||||
{" ab ", identifier, 1, 2, 1, 4, "ab"},
|
||||
{"\nab", identifier, 2, 1, 2, 3, "ab"},
|
||||
|
||||
{"ab\n", identifier, 1, 1, 1, 3, "ab"},
|
||||
{"\nab\n", identifier, 2, 1, 2, 3, "ab"},
|
||||
{"\n ab", identifier, 2, 2, 2, 4, "ab"},
|
||||
{"ab \n", identifier, 1, 1, 1, 3, "ab"},
|
||||
{"\n ab \n", identifier, 2, 2, 2, 4, "ab"},
|
||||
|
||||
{"c", identifier, 1, 1, 1, 2, "c"},
|
||||
{"cR", identifier, 1, 1, 1, 3, "cR"},
|
||||
{"cRe", identifier, 1, 1, 1, 4, "cRe"},
|
||||
{"cReA", identifier, 1, 1, 1, 5, "cReA"},
|
||||
{"cReAt", identifier, 1, 1, 1, 6, "cReAt"},
|
||||
|
||||
{"cReATe", create, 1, 1, 1, 7, "cReATe"},
|
||||
{"cReATeD", identifier, 1, 1, 1, 8, "cReATeD"},
|
||||
{"2", intLit, 1, 1, 1, 2, "2"},
|
||||
{"2.", floatLit, 1, 1, 1, 3, "2."},
|
||||
{"2.3", floatLit, 1, 1, 1, 4, "2.3"},
|
||||
}
|
||||
|
||||
lval := &yySymType{}
|
||||
for _, t := range table {
|
||||
l := NewLexer(t.src)
|
||||
tok := l.Lex(lval)
|
||||
nline, ncol := l.npos()
|
||||
val := string(l.val)
|
||||
|
||||
c.Assert(tok, Equals, t.tok)
|
||||
c.Assert(l.line, Equals, t.line)
|
||||
c.Assert(l.col, Equals, t.col)
|
||||
c.Assert(nline, Equals, t.nline)
|
||||
c.Assert(ncol, Equals, t.ncol)
|
||||
c.Assert(val, Equals, t.val)
|
||||
}
|
||||
}
|
||||
@ -62,7 +62,7 @@ func (s *testStmtSuite) TestDropTable(c *C) {
|
||||
}
|
||||
|
||||
func (s *testStmtSuite) TestDropIndex(c *C) {
|
||||
testSQL := "drop index if exists drop_index;"
|
||||
testSQL := "drop index if exists drop_index on t;"
|
||||
|
||||
stmtList, err := tidb.Compile(s.ctx, testSQL)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
@ -51,5 +51,5 @@ func (s *testStmtSuite) TestShow(c *C) {
|
||||
c.Assert(stmtList, HasLen, 1)
|
||||
testStmt, ok = stmtList[0].(*stmts.ShowStmt)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(testStmt.Pattern, NotNil)
|
||||
c.Assert(testStmt.Pattern, NotNil, Commentf("S: %s", testStmt.Text))
|
||||
}
|
||||
|
||||
21
tidb.go
21
tidb.go
@ -26,7 +26,6 @@ import (
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/ngaut/log"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/domain"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
@ -132,15 +131,12 @@ func Compile(ctx context.Context, src string) ([]stmt.Statement, error) {
|
||||
rawStmt := l.Stmts()
|
||||
stmts := make([]stmt.Statement, len(rawStmt))
|
||||
for i, v := range rawStmt {
|
||||
if node, ok := v.(ast.Node); ok {
|
||||
stm, err := optimizer.Compile(node)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
stmts[i] = stm
|
||||
} else {
|
||||
stmts[i] = v.(stmt.Statement)
|
||||
compiler := &optimizer.Compiler{}
|
||||
stm, err := compiler.Compile(v)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
stmts[i] = stm
|
||||
}
|
||||
return stmts, nil
|
||||
}
|
||||
@ -162,7 +158,12 @@ func CompilePrepare(ctx context.Context, src string) (stmt.Statement, []*express
|
||||
return nil, nil, nil
|
||||
}
|
||||
sm := sms[0]
|
||||
return sm.(stmt.Statement), l.ParamList, nil
|
||||
compiler := &optimizer.Compiler{}
|
||||
statement, err := compiler.Compile(sm)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
return statement, compiler.ParamMarkers(), nil
|
||||
}
|
||||
|
||||
func prepareStmt(ctx context.Context, sqlText string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) {
|
||||
|
||||
@ -640,6 +640,7 @@ func (s *testSessionSuite) TestSelectForUpdate(c *C) {
|
||||
// conflict
|
||||
mustExecSQL(c, se1, "begin")
|
||||
rs, err := exec(c, se1, "select * from t where c1=11 for update")
|
||||
c.Assert(err, IsNil)
|
||||
_, err = rs.Rows(-1, 0)
|
||||
|
||||
mustExecSQL(c, se2, "begin")
|
||||
|
||||
Reference in New Issue
Block a user