Files
tidb/planner/core/plan_cache_param.go

205 lines
6.3 KiB
Go

// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package core
import (
"context"
"errors"
"strings"
"sync"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/parser"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/format"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
driver "github.com/pingcap/tidb/types/parser_driver"
)
var (
	// paramReplacerPool caches paramReplacers used by ParameterizeAST.
	paramReplacerPool = sync.Pool{New: func() interface{} {
		pr := new(paramReplacer)
		pr.Reset()
		return pr
	}}
	// paramRestorerPool caches paramRestorers used by RestoreASTWithParams.
	paramRestorerPool = sync.Pool{New: func() interface{} {
		pr := new(paramRestorer)
		pr.Reset()
		return pr
	}}
	// paramCtxPool caches RestoreCtxs that write into a *strings.Builder.
	// Their flags are DefaultRestoreFlags with RestoreKeyWordUppercase and
	// RestoreNameBackQuotes removed and RestoreStringWithoutCharset added.
	paramCtxPool = sync.Pool{New: func() interface{} {
		restoreCtx := format.NewRestoreCtx(format.DefaultRestoreFlags, new(strings.Builder))
		// Clear the bits with &^= instead of toggling them with ^=: a toggle
		// would turn these flags ON if DefaultRestoreFlags ever stopped
		// including them.
		restoreCtx.Flags &^= format.RestoreKeyWordUppercase | format.RestoreNameBackQuotes
		restoreCtx.Flags |= format.RestoreStringWithoutCharset
		return restoreCtx
	}}
)
// paramReplacer is an ast.Visitor that swaps literal values (*driver.ValueExpr)
// for `?` parameter markers and remembers the replaced values in order.
type paramReplacer struct {
	// params holds the replaced values; a marker's Offset indexes into it.
	params []*driver.ValueExpr
	// selFieldsCnt is the current SelectField nesting depth. It is a counter
	// rather than a bool so that nested SelectFields are handled. Values inside
	// a SelectField are left alone so that output field names keep the original
	// text, e.g. `select a+1 from t where a<10 and b<23` becomes
	// `select a+1 from t where a<? and b<?` rather than
	// `select a+? from t where a<? and b<?`.
	selFieldsCnt int
	// groupByCnt is the current GroupByClause nesting depth. Values in GROUP BY
	// are left alone because they can affect the full_group_by check, e.g.
	// `select a*2 from t group by a*?` cannot pass that check.
	groupByCnt int
}
// Enter implements ast.Visitor. It tracks entry into SelectField and
// GroupByClause nodes. Outside of both, it replaces each *driver.ValueExpr with
// a ParamMarkerExpr whose Offset is the value's index in pr.params and whose
// Datum is a clone of the value's Datum.
func (pr *paramReplacer) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
	switch node := in.(type) {
	case *ast.SelectField:
		pr.selFieldsCnt++
	case *ast.GroupByClause:
		pr.groupByCnt++
	case *driver.ValueExpr:
		if pr.selFieldsCnt != 0 || pr.groupByCnt != 0 {
			// Inside a SelectField or GROUP BY: keep the literal as-is.
			break
		}
		// The offset doubles as the parameter order in the non-prepared plan cache.
		offset := len(pr.params)
		pr.params = append(pr.params, node)
		marker := ast.NewParamMarkerExpr(offset)
		marker.(*driver.ParamMarkerExpr).Datum = *node.Datum.Clone()
		return marker, true
	}
	return in, false
}
// Leave implements ast.Visitor. It undoes the depth bookkeeping done by Enter
// for SelectField and GroupByClause nodes and never aborts the walk.
func (pr *paramReplacer) Leave(in ast.Node) (out ast.Node, ok bool) {
	if _, isField := in.(*ast.SelectField); isField {
		pr.selFieldsCnt--
	} else if _, isGroupBy := in.(*ast.GroupByClause); isGroupBy {
		pr.groupByCnt--
	}
	return in, true
}
// Reset returns the replacer to its zero state so it can go back into the pool.
// params is dropped rather than truncated because the previous slice has been
// handed to the caller of ParameterizeAST and must not be overwritten.
func (pr *paramReplacer) Reset() {
	pr.params = nil
	pr.selFieldsCnt = 0
	pr.groupByCnt = 0
}
// GetParamSQLFromAST returns the parameterized SQL of stmt together with the
// literal values that were replaced by `?`.
// e.g. `select * from t where a<10 and b<23` --> `select * from t where a<? and b<?`, [10, 23].
// NOTICE: stmt is parameterized in place and then restored, so on success it
// is left as it was on entry. If an error is returned, stmt may be left
// partially parameterized and should not be reused.
func GetParamSQLFromAST(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode) (paramSQL string, params []*driver.ValueExpr, err error) {
	paramSQL, params, err = ParameterizeAST(ctx, sctx, stmt)
	if err != nil {
		return "", nil, err
	}
	if err = RestoreASTWithParams(ctx, sctx, stmt, params); err != nil {
		// The AST could not be put back, so don't return results alongside
		// the error; callers must treat the outputs as unusable.
		return "", nil, err
	}
	return paramSQL, params, nil
}
// ParameterizeAST parameterizes this StmtNode in place and returns the
// resulting SQL text together with the literal values that were replaced.
// e.g. `select * from t where a<10 and b<23` --> `select * from t where a<? and b<?`, [10, 23].
// Literals inside SelectFields and GROUP BY clauses are kept as-is.
// NOTICE: on success stmt is left parameterized (its values are replaced with
// ParamMarkerExprs); pass the returned params to RestoreASTWithParams to undo
// that. If generating the SQL text fails, the replacement is undone on a
// best-effort basis before the error is returned.
func ParameterizeAST(ctx context.Context, sctx sessionctx.Context, stmt ast.StmtNode) (paramSQL string, params []*driver.ValueExpr, err error) {
	pr := paramReplacerPool.Get().(*paramReplacer)
	pCtx := paramCtxPool.Get().(*format.RestoreCtx)
	defer func() {
		pr.Reset()
		paramReplacerPool.Put(pr)
		pCtx.In.(*strings.Builder).Reset()
		paramCtxPool.Put(pCtx)
	}()
	stmt.Accept(pr)
	if err = stmt.Restore(pCtx); err != nil {
		// Without this, the caller's AST would be left full of ParamMarkerExprs.
		// pr.params holds the very ValueExpr nodes that were swapped out, so
		// restoring from them is an exact inverse. Its own error (only raised
		// for an out-of-range offset) is secondary; report the Restore failure.
		_ = RestoreASTWithParams(ctx, sctx, stmt, pr.params)
		return "", nil, err
	}
	return pCtx.In.(*strings.Builder).String(), pr.params, nil
}
// paramRestorer is an ast.Visitor that undoes paramReplacer: it replaces every
// *driver.ParamMarkerExpr with the ValueExpr at its Offset in params.
type paramRestorer struct {
	params []*driver.ValueExpr
	// err records the first failure; once set, the walk is aborted.
	err error
}

// Enter implements ast.Visitor.
func (pr *paramRestorer) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
	if pr.err != nil {
		// A previous node failed. Keep this one unchanged; Leave reports !ok
		// so the walk stops.
		return in, true
	}
	if marker, ok := in.(*driver.ParamMarkerExpr); ok {
		if marker.Offset < 0 || marker.Offset >= len(pr.params) {
			pr.err = errors.New("failed to restore ast.Node")
			// Hand back the original node, not nil: ast Accept implementations
			// typically type-assert the node a child returns, so a nil here can
			// panic instead of surfacing pr.err.
			return in, true
		}
		// offset is used as order in non-prepared plan cache.
		return pr.params[marker.Offset], true
	}
	return in, false
}

// Leave implements ast.Visitor. It returns ok=false once an error has been
// recorded, which tells the enclosing Accept calls to stop walking.
func (pr *paramRestorer) Leave(in ast.Node) (out ast.Node, ok bool) {
	return in, pr.err == nil
}

// Reset clears the restorer so it can be returned to paramRestorerPool.
func (pr *paramRestorer) Reset() {
	pr.params, pr.err = nil, nil
}
// RestoreASTWithParams puts the values in params back into a parameterized AST,
// replacing each ParamMarkerExpr with params[Offset]. It returns an error when
// a marker's Offset is past the end of params.
// e.g. `select * from t where a<? and b<?`, [10, 23] --> `select * from t where a<10 and b<23`.
func RestoreASTWithParams(ctx context.Context, _ sessionctx.Context, stmt ast.StmtNode, params []*driver.ValueExpr) error {
	restorer := paramRestorerPool.Get().(*paramRestorer)
	// Deferred calls run in reverse order: Reset first, then Put back.
	defer paramRestorerPool.Put(restorer)
	defer restorer.Reset()

	restorer.params = params
	stmt.Accept(restorer)
	return restorer.err
}
// Params2Expressions turns each parameter into an *expression.Constant holding
// the parameter's Datum, with a return type inferred from that Datum. The
// result has the same length and order as params.
func Params2Expressions(params []*driver.ValueExpr) []expression.Expression {
	exprs := make([]expression.Expression, len(params))
	for i, param := range params {
		retType := new(types.FieldType)
		types.InferParamTypeFromDatum(&param.Datum, retType)
		exprs[i] = &expression.Constant{RetType: retType, Value: param.Datum}
	}
	return exprs
}
// parserPool caches parsers for ParseParameterizedSQL. A pooled parser keeps
// settings from its previous use, so callers must reconfigure it every time.
var parserPool = &sync.Pool{New: func() interface{} { return parser.New() }}

// ParseParameterizedSQL parses paramSQL using the SQL mode, parser config and
// parse params of sctx's session, and returns its only statement.
func ParseParameterizedSQL(sctx sessionctx.Context, paramSQL string) (ast.StmtNode, error) {
	p := parserPool.Get().(*parser.Parser)
	defer parserPool.Put(p)

	vars := sctx.GetSessionVars()
	p.SetSQLMode(vars.SQLMode)
	p.SetParserConfig(vars.BuildParserConfig())
	// Parser warnings are intentionally discarded; only hard errors matter here.
	stmts, _, err := p.ParseSQL(paramSQL, vars.GetParseParams()...)
	switch {
	case err != nil:
		return nil, err
	case len(stmts) != 1:
		return nil, errors.New("unexpected multiple statements")
	}
	return stmts[0], nil
}