Files
tidb/bindinfo/bind.go

170 lines
5.3 KiB
Go

// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bindinfo
import "github.com/pingcap/parser/ast"
// BindHint copies the optimizer hints found in hintedStmt onto originStmt and
// returns the updated statement. originStmt is modified in place.
//
// The two statements are walked in lockstep, so hintedStmt is expected to have
// the same structure as originStmt apart from its hints (NOTE(review): callers
// presumably guarantee this by matching normalized SQL — confirm). Only SELECT
// statements are supported: any other statement kind, or a hintedStmt that is
// not a non-nil SELECT when originStmt is one, is returned unchanged.
func BindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode {
	switch x := originStmt.(type) {
	case *ast.SelectStmt:
		hinted, ok := hintedStmt.(*ast.SelectStmt)
		if !ok || hinted == nil {
			// Mismatched statement kinds: there is nothing sensible to copy, so
			// leave the original untouched instead of panicking.
			return originStmt
		}
		return selectBind(x, hinted)
	default:
		return originStmt
	}
}
// selectBind copies hints from hintedNode onto originalNode, which is modified
// in place and returned.
//
// Statement-level table hints are taken from hintedNode whenever it has any.
// Hints nested deeper (index hints on tables, hints inside subqueries) are
// copied by walking the FROM clause, WHERE, HAVING, ORDER BY and the select
// field list of both statements in lockstep. A clause that is missing on
// either side is skipped, and select fields are paired by position only up to
// the shorter of the two lists.
func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt {
	if originalNode == nil || hintedNode == nil {
		return originalNode
	}
	if hintedNode.TableHints != nil {
		originalNode.TableHints = hintedNode.TableHints
	}
	if originalNode.From != nil && hintedNode.From != nil &&
		originalNode.From.TableRefs != nil && hintedNode.From.TableRefs != nil {
		// TableRefs is already a *ast.Join, so bind it directly instead of
		// round-tripping through the ResultSetNode interface.
		originalNode.From.TableRefs = joinBind(originalNode.From.TableRefs, hintedNode.From.TableRefs)
	}
	if originalNode.Where != nil && hintedNode.Where != nil {
		originalNode.Where = exprBind(originalNode.Where, hintedNode.Where)
	}
	if originalNode.Having != nil && hintedNode.Having != nil {
		originalNode.Having.Expr = exprBind(originalNode.Having.Expr, hintedNode.Having.Expr)
	}
	if originalNode.OrderBy != nil && hintedNode.OrderBy != nil {
		originalNode.OrderBy = orderByBind(originalNode.OrderBy, hintedNode.OrderBy)
	}
	if originalNode.Fields != nil && hintedNode.Fields != nil {
		origFields := originalNode.Fields.Fields
		hintFields := hintedNode.Fields.Fields
		for idx := range origFields {
			// Guard against a hinted field list that is shorter than the
			// original one rather than indexing out of range.
			if idx >= len(hintFields) {
				break
			}
			origFields[idx].Expr = exprBind(origFields[idx].Expr, hintFields[idx].Expr)
		}
	}
	return originalNode
}
// orderByBind copies hints from the ORDER BY items of hintedNode onto the items
// of originalNode at the same position. originalNode is modified in place and
// returned. If either clause is nil nothing is done, and items beyond the
// shorter list are left untouched.
func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause {
	if originalNode == nil || hintedNode == nil {
		return originalNode
	}
	hintedItems := hintedNode.Items
	for idx := range originalNode.Items {
		if idx >= len(hintedItems) {
			break
		}
		originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedItems[idx].Expr)
	}
	return originalNode
}
// exprBind walks originalNode and hintedNode in lockstep and, for every
// subquery found inside originalNode, copies the hints of the subquery at the
// same position in hintedNode (via resultSetNodeBind). originalNode is
// modified in place and returned.
//
// Positions where hintedNode is nil or holds a different node type are left
// unchanged instead of panicking on a failed type assertion. Expression kinds
// not listed in the switch are not descended into and are returned as-is.
func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode {
	if originalNode == nil || hintedNode == nil {
		return originalNode
	}
	switch v := originalNode.(type) {
	case *ast.SubqueryExpr:
		if h, ok := hintedNode.(*ast.SubqueryExpr); ok && v.Query != nil && h.Query != nil {
			v.Query = resultSetNodeBind(v.Query, h.Query)
		}
	case *ast.ExistsSubqueryExpr:
		// Sel holds the subquery; recursing reuses the SubqueryExpr case above.
		if h, ok := hintedNode.(*ast.ExistsSubqueryExpr); ok {
			v.Sel = exprBind(v.Sel, h.Sel)
		}
	case *ast.PatternInExpr:
		if h, ok := hintedNode.(*ast.PatternInExpr); ok {
			v.Expr = exprBind(v.Expr, h.Expr)
			exprListBind(v.List, h.List)
			v.Sel = exprBind(v.Sel, h.Sel)
		}
	case *ast.BinaryOperationExpr:
		if h, ok := hintedNode.(*ast.BinaryOperationExpr); ok {
			v.L = exprBind(v.L, h.L)
			v.R = exprBind(v.R, h.R)
		}
	case *ast.IsNullExpr:
		if h, ok := hintedNode.(*ast.IsNullExpr); ok {
			v.Expr = exprBind(v.Expr, h.Expr)
		}
	case *ast.IsTruthExpr:
		if h, ok := hintedNode.(*ast.IsTruthExpr); ok {
			v.Expr = exprBind(v.Expr, h.Expr)
		}
	case *ast.PatternLikeExpr:
		if h, ok := hintedNode.(*ast.PatternLikeExpr); ok {
			v.Expr = exprBind(v.Expr, h.Expr)
			v.Pattern = exprBind(v.Pattern, h.Pattern)
		}
	case *ast.CompareSubqueryExpr:
		if h, ok := hintedNode.(*ast.CompareSubqueryExpr); ok {
			v.L = exprBind(v.L, h.L)
			v.R = exprBind(v.R, h.R)
		}
	case *ast.BetweenExpr:
		if h, ok := hintedNode.(*ast.BetweenExpr); ok {
			v.Expr = exprBind(v.Expr, h.Expr)
			v.Left = exprBind(v.Left, h.Left)
			v.Right = exprBind(v.Right, h.Right)
		}
	case *ast.UnaryOperationExpr:
		if h, ok := hintedNode.(*ast.UnaryOperationExpr); ok {
			v.V = exprBind(v.V, h.V)
		}
	case *ast.ParenthesesExpr:
		// A parenthesized subquery such as `(SELECT ...)` is wrapped here;
		// without this case its hints were never reached.
		if h, ok := hintedNode.(*ast.ParenthesesExpr); ok {
			v.Expr = exprBind(v.Expr, h.Expr)
		}
	case *ast.FuncCallExpr:
		if h, ok := hintedNode.(*ast.FuncCallExpr); ok {
			exprListBind(v.Args, h.Args)
		}
	case *ast.RowExpr:
		if h, ok := hintedNode.(*ast.RowExpr); ok {
			exprListBind(v.Values, h.Values)
		}
	case *ast.CaseExpr:
		if h, ok := hintedNode.(*ast.CaseExpr); ok {
			v.Value = exprBind(v.Value, h.Value)
			for idx, when := range v.WhenClauses {
				if idx >= len(h.WhenClauses) {
					break
				}
				hintedWhen := h.WhenClauses[idx]
				if when == nil || hintedWhen == nil {
					continue
				}
				when.Expr = exprBind(when.Expr, hintedWhen.Expr)
				when.Result = exprBind(when.Result, hintedWhen.Result)
			}
			v.ElseClause = exprBind(v.ElseClause, h.ElseClause)
		}
	}
	return originalNode
}

// exprListBind applies exprBind pairwise to originals and hinted, writing the
// results back into originals. Positions beyond the shorter list are skipped.
func exprListBind(originals, hinted []ast.ExprNode) {
	for idx := range originals {
		if idx >= len(hinted) {
			return
		}
		originals[idx] = exprBind(originals[idx], hinted[idx])
	}
}
// resultSetNodeBind copies hints from hintedNode onto originalNode, dispatching
// on the concrete result-set type of originalNode:
//   - *ast.Join: both operands are bound recursively via joinBind;
//   - *ast.TableSource: a plain table receives the hinted table's index hints,
//     and a derived table (SELECT or UNION) is bound as a nested statement;
//   - *ast.SelectStmt / *ast.UnionStmt: bound directly.
//
// originalNode is modified in place and returned. When hintedNode (or its
// table source) is nil or has a different concrete type, the corresponding
// part of originalNode is left unchanged instead of panicking.
func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode {
	switch x := originalNode.(type) {
	case *ast.Join:
		if h, ok := hintedNode.(*ast.Join); ok && x != nil && h != nil {
			return joinBind(x, h)
		}
		return x
	case *ast.TableSource:
		ts, ok := hintedNode.(*ast.TableSource)
		if !ok || ts == nil || x == nil {
			return x
		}
		switch v := x.Source.(type) {
		case *ast.SelectStmt:
			if hs, ok := ts.Source.(*ast.SelectStmt); ok && hs != nil {
				x.Source = selectBind(v, hs)
			}
		case *ast.UnionStmt:
			if hu, ok := ts.Source.(*ast.UnionStmt); ok && hu != nil {
				x.Source = unionSelectBind(v, hu)
			}
		case *ast.TableName:
			if ht, ok := ts.Source.(*ast.TableName); ok && ht != nil {
				v.IndexHints = ht.IndexHints
			}
		}
		return x
	case *ast.SelectStmt:
		if h, ok := hintedNode.(*ast.SelectStmt); ok && h != nil {
			return selectBind(x, h)
		}
		return x
	case *ast.UnionStmt:
		if h, ok := hintedNode.(*ast.UnionStmt); ok && h != nil {
			return unionSelectBind(x, h)
		}
		return x
	default:
		return x
	}
}
// joinBind copies hints from hintedNode onto originalNode by binding the left
// and right operands of the two joins pairwise. originalNode is modified in
// place and returned. An operand that is nil on either side is skipped; both
// operands use the same check so a missing hinted operand can never be
// dereferenced further down.
func joinBind(originalNode, hintedNode *ast.Join) *ast.Join {
	if originalNode == nil || hintedNode == nil {
		return originalNode
	}
	if originalNode.Left != nil && hintedNode.Left != nil {
		originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left)
	}
	if originalNode.Right != nil && hintedNode.Right != nil {
		originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right)
	}
	return originalNode
}
// unionSelectBind copies hints from each SELECT of hintedNode onto the SELECT
// at the same position in originalNode. originalNode is modified in place and
// returned. Nothing is done when either select list is missing, and selects
// beyond the shorter list (or nil entries) are left untouched.
func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode {
	if originalNode == nil || hintedNode == nil ||
		originalNode.SelectList == nil || hintedNode.SelectList == nil {
		return originalNode
	}
	selects := originalNode.SelectList.Selects
	hintedSelects := hintedNode.SelectList.Selects
	// Each SELECT is bound independently, so a plain forward walk suffices.
	for i := range selects {
		if i >= len(hintedSelects) {
			break
		}
		if selects[i] == nil || hintedSelects[i] == nil {
			continue
		}
		selects[i] = selectBind(selects[i], hintedSelects[i])
	}
	return originalNode
}