170 lines
5.3 KiB
Go
170 lines
5.3 KiB
Go
// Copyright 2019 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package bindinfo
|
|
|
|
import "github.com/pingcap/parser/ast"
|
|
|
|
// BindHint will add hints for originStmt according to hintedStmt' hints.
|
|
func BindHint(originStmt, hintedStmt ast.StmtNode) ast.StmtNode {
|
|
switch x := originStmt.(type) {
|
|
case *ast.SelectStmt:
|
|
return selectBind(x, hintedStmt.(*ast.SelectStmt))
|
|
default:
|
|
return originStmt
|
|
}
|
|
}
|
|
|
|
func selectBind(originalNode, hintedNode *ast.SelectStmt) *ast.SelectStmt {
|
|
if hintedNode.TableHints != nil {
|
|
originalNode.TableHints = hintedNode.TableHints
|
|
}
|
|
if originalNode.From != nil {
|
|
originalNode.From.TableRefs = resultSetNodeBind(originalNode.From.TableRefs, hintedNode.From.TableRefs).(*ast.Join)
|
|
}
|
|
if originalNode.Where != nil {
|
|
originalNode.Where = exprBind(originalNode.Where, hintedNode.Where).(ast.ExprNode)
|
|
}
|
|
|
|
if originalNode.Having != nil {
|
|
originalNode.Having.Expr = exprBind(originalNode.Having.Expr, hintedNode.Having.Expr)
|
|
}
|
|
|
|
if originalNode.OrderBy != nil {
|
|
originalNode.OrderBy = orderByBind(originalNode.OrderBy, hintedNode.OrderBy)
|
|
}
|
|
|
|
if originalNode.Fields != nil {
|
|
origFields := originalNode.Fields.Fields
|
|
hintFields := hintedNode.Fields.Fields
|
|
for idx := range origFields {
|
|
origFields[idx].Expr = exprBind(origFields[idx].Expr, hintFields[idx].Expr)
|
|
}
|
|
}
|
|
return originalNode
|
|
}
|
|
|
|
func orderByBind(originalNode, hintedNode *ast.OrderByClause) *ast.OrderByClause {
|
|
for idx := 0; idx < len(originalNode.Items); idx++ {
|
|
originalNode.Items[idx].Expr = exprBind(originalNode.Items[idx].Expr, hintedNode.Items[idx].Expr)
|
|
}
|
|
return originalNode
|
|
}
|
|
|
|
func exprBind(originalNode, hintedNode ast.ExprNode) ast.ExprNode {
|
|
switch v := originalNode.(type) {
|
|
case *ast.SubqueryExpr:
|
|
if v.Query != nil {
|
|
v.Query = resultSetNodeBind(v.Query, hintedNode.(*ast.SubqueryExpr).Query)
|
|
}
|
|
case *ast.ExistsSubqueryExpr:
|
|
if v.Sel != nil {
|
|
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.ExistsSubqueryExpr).Sel.(*ast.SubqueryExpr).Query)
|
|
}
|
|
case *ast.PatternInExpr:
|
|
if v.Sel != nil {
|
|
v.Sel.(*ast.SubqueryExpr).Query = resultSetNodeBind(v.Sel.(*ast.SubqueryExpr).Query, hintedNode.(*ast.PatternInExpr).Sel.(*ast.SubqueryExpr).Query)
|
|
}
|
|
case *ast.BinaryOperationExpr:
|
|
if v.L != nil {
|
|
v.L = exprBind(v.L, hintedNode.(*ast.BinaryOperationExpr).L)
|
|
}
|
|
if v.R != nil {
|
|
v.R = exprBind(v.R, hintedNode.(*ast.BinaryOperationExpr).R)
|
|
}
|
|
case *ast.IsNullExpr:
|
|
if v.Expr != nil {
|
|
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsNullExpr).Expr)
|
|
}
|
|
case *ast.IsTruthExpr:
|
|
if v.Expr != nil {
|
|
v.Expr = exprBind(v.Expr, hintedNode.(*ast.IsTruthExpr).Expr)
|
|
}
|
|
case *ast.PatternLikeExpr:
|
|
if v.Pattern != nil {
|
|
v.Pattern = exprBind(v.Pattern, hintedNode.(*ast.PatternLikeExpr).Pattern)
|
|
}
|
|
case *ast.CompareSubqueryExpr:
|
|
if v.L != nil {
|
|
v.L = exprBind(v.L, hintedNode.(*ast.CompareSubqueryExpr).L)
|
|
}
|
|
if v.R != nil {
|
|
v.R = exprBind(v.R, hintedNode.(*ast.CompareSubqueryExpr).R)
|
|
}
|
|
case *ast.BetweenExpr:
|
|
if v.Left != nil {
|
|
v.Left = exprBind(v.Left, hintedNode.(*ast.BetweenExpr).Left)
|
|
}
|
|
if v.Right != nil {
|
|
v.Right = exprBind(v.Right, hintedNode.(*ast.BetweenExpr).Right)
|
|
}
|
|
case *ast.UnaryOperationExpr:
|
|
if v.V != nil {
|
|
v.V = exprBind(v.V, hintedNode.(*ast.UnaryOperationExpr).V)
|
|
}
|
|
case *ast.CaseExpr:
|
|
if v.Value != nil {
|
|
v.Value = exprBind(v.Value, hintedNode.(*ast.CaseExpr).Value)
|
|
}
|
|
if v.ElseClause != nil {
|
|
v.ElseClause = exprBind(v.ElseClause, hintedNode.(*ast.CaseExpr).ElseClause)
|
|
}
|
|
}
|
|
return originalNode
|
|
}
|
|
|
|
func resultSetNodeBind(originalNode, hintedNode ast.ResultSetNode) ast.ResultSetNode {
|
|
switch x := originalNode.(type) {
|
|
case *ast.Join:
|
|
return joinBind(x, hintedNode.(*ast.Join))
|
|
case *ast.TableSource:
|
|
ts, _ := hintedNode.(*ast.TableSource)
|
|
switch v := x.Source.(type) {
|
|
case *ast.SelectStmt:
|
|
x.Source = selectBind(v, ts.Source.(*ast.SelectStmt))
|
|
case *ast.UnionStmt:
|
|
x.Source = unionSelectBind(v, hintedNode.(*ast.TableSource).Source.(*ast.UnionStmt))
|
|
case *ast.TableName:
|
|
x.Source.(*ast.TableName).IndexHints = ts.Source.(*ast.TableName).IndexHints
|
|
}
|
|
return x
|
|
case *ast.SelectStmt:
|
|
return selectBind(x, hintedNode.(*ast.SelectStmt))
|
|
case *ast.UnionStmt:
|
|
return unionSelectBind(x, hintedNode.(*ast.UnionStmt))
|
|
default:
|
|
return x
|
|
}
|
|
}
|
|
|
|
func joinBind(originalNode, hintedNode *ast.Join) *ast.Join {
|
|
if originalNode.Left != nil {
|
|
originalNode.Left = resultSetNodeBind(originalNode.Left, hintedNode.Left)
|
|
}
|
|
|
|
if hintedNode.Right != nil {
|
|
originalNode.Right = resultSetNodeBind(originalNode.Right, hintedNode.Right)
|
|
}
|
|
|
|
return originalNode
|
|
}
|
|
|
|
func unionSelectBind(originalNode, hintedNode *ast.UnionStmt) ast.ResultSetNode {
|
|
selects := originalNode.SelectList.Selects
|
|
for i := len(selects) - 1; i >= 0; i-- {
|
|
originalNode.SelectList.Selects[i] = selectBind(selects[i], hintedNode.SelectList.Selects[i])
|
|
}
|
|
|
|
return originalNode
|
|
}
|