Files
tidb/pkg/bindinfo/binding_plan_generation.go

586 lines
20 KiB
Go

// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package bindinfo
import (
"container/list"
"fmt"
"sort"
"strconv"
"strings"
"github.com/pingcap/tidb/pkg/parser"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/planner/util/fixcontrol"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/hint"
)
// PlanGenerator is used to generate new Plan Candidates for this specified query.
type PlanGenerator interface {
// Generate returns the candidate plans produced for the statement text sql.
// Judging by the parameter names, defaultSchema is the schema used for
// unqualified table names and charset/collation are the settings the
// statement is parsed with.
Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error)
}
// planGenerator implements PlanGenerator.
// It generates new plans via adjusting the optimizer variables and fixes.
type planGenerator struct {
// sPool is handed to callWithSCtx by Generate to obtain the session context
// in which plan generation runs.
sPool util.DestroyableSessionPool
}
// Generate generates new plans for the given SQL statement.
//
// Only statements whose text starts with SELECT (compared case-insensitively
// after trimming surrounding whitespace) are handled; for anything else it
// returns (nil, nil). Every plan found by generatePlanWithSCtx is wrapped in a
// Binding whose BindSQL is the original statement with the plan's hints
// injected right after the SELECT keyword, and prepareHints is run on it.
//
// On error the returned slice is always nil, so callers never observe a
// partially built result.
func (g *planGenerator) Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error) {
	// TODO: only support SQL starting with SELECT for now, support other types of SQLs later.
	// TODO: make this check more strict.
	const prefix = "SELECT"
	sql = strings.TrimSpace(sql)
	if len(sql) < len(prefix) || !strings.EqualFold(sql[:len(prefix)], prefix) {
		return nil, nil // not a SELECT statement
	}
	err = callWithSCtx(g.sPool, false, func(sctx sessionctx.Context) error {
		candidates, err := generatePlanWithSCtx(sctx, defaultSchema, sql, charset, collation)
		if err != nil {
			return err
		}
		// Build into a local slice and publish it only once every binding has
		// been prepared, so a failure half-way leaves plans untouched.
		result := make([]*BindingPlanInfo, 0, len(candidates))
		for _, gp := range candidates {
			// TODO: construct bindingSQL in a more strict way.
			bindingSQL := sql[:len(prefix)] + " /*+ " + gp.planHints + " */ " + sql[len(prefix):]
			binding := &Binding{
				OriginalSQL: sql,
				BindSQL:     bindingSQL,
				Db:          defaultSchema,
				Source:      "generated",
				PlanDigest:  gp.planDigest,
			}
			if err := prepareHints(sctx, binding); err != nil {
				return err
			}
			result = append(result, &BindingPlanInfo{
				Binding: binding,
				Plan:    gp.PlanText(),
			})
		}
		plans = result
		return nil
	})
	if err != nil {
		return nil, err
	}
	return plans, nil
}
// tableName identifies a table by schema and table name.
type tableName struct {
	schema string
	name   string
}

// String renders the table as "schema.name".
func (t *tableName) String() string {
	// fmt.Sprint inserts no spaces here because every operand is a string.
	return fmt.Sprint(t.schema, ".", t.name)
}
// genedPlan represents a plan generated by planGenerator.
type genedPlan struct {
	planDigest string     // digest of this plan
	planHints  string     // a set of hints to reproduce this plan
	planText   [][]string // human-readable plan text
}

// PlanText flattens planText into a single string: the columns of a row are
// joined with tabs and the rows are joined with newlines.
func (gp *genedPlan) PlanText() string {
	rows := make([]string, 0, len(gp.planText))
	for _, cols := range gp.planText {
		rows = append(rows, strings.Join(cols, "\t"))
	}
	return strings.Join(rows, "\n")
}
// state represents a state of the optimizer variables and fixes.
// The slices are index-aligned: varValues[i] is the value of varNames[i] and
// fixValues[i] is the value of fixIDs[i].
type state struct {
leading2 [2]*tableName // leading-2 table names; a LEADING hint is only built when both entries are non-nil
varNames []string // relevant variables and their values to generate a certain plan
varValues []any
fixIDs []uint64 // relevant fixes and their values to generate a certain plan
fixValues []string
}
// Encode encodes the state into a string.
// The non-nil leading tables, then every variable value, then every fix value
// are written in order, separated by commas. A comma is only emitted when
// something has already been written.
func (s *state) Encode() string {
	var buf strings.Builder
	separate := func() {
		if buf.Len() > 0 {
			buf.WriteString(",")
		}
	}
	for _, tbl := range s.leading2 {
		if tbl != nil {
			separate()
			buf.WriteString(tbl.String())
		}
	}
	for _, val := range s.varValues {
		separate()
		switch val.(type) {
		case float64:
			// only consider 4 decimal digits, which should be enough for optimizer tuning.
			fmt.Fprintf(&buf, "%.4f", val)
		default:
			fmt.Fprintf(&buf, "%v", val)
		}
	}
	for _, fixVal := range s.fixValues {
		separate()
		buf.WriteString(fixVal)
	}
	return buf.String()
}
// newStateWithLeading2 returns a new state equal to old except for its
// leading-2 pair. The variable and fix slices are shared with old, not copied.
func newStateWithLeading2(old *state, leading2 [2]*tableName) *state {
	derived := *old // shallow copy: slice headers are shared with old
	derived.leading2 = leading2
	return &derived
}
// newStateWithNewVar returns a new state equal to old except that the first
// variable named varName gets the value varVal. varValues is copied so old is
// left untouched; all other slices are shared with old.
func newStateWithNewVar(old *state, varName string, varVal any) *state {
	derived := *old
	derived.varValues = make([]any, len(old.varValues))
	copy(derived.varValues, old.varValues)
	for idx, name := range derived.varNames {
		if name == varName {
			derived.varValues[idx] = varVal
			break
		}
	}
	return &derived
}
// newStateWithNewFix returns a new state equal to old except that the first
// fix with ID fixID gets the value fixVal. fixValues is copied so old is left
// untouched; all other slices are shared with old.
func newStateWithNewFix(old *state, fixID uint64, fixVal string) *state {
	derived := *old
	derived.fixValues = make([]string, len(old.fixValues))
	copy(derived.fixValues, old.fixValues)
	for idx, id := range derived.fixIDs {
		if id == fixID {
			derived.fixValues[idx] = fixVal
			break
		}
	}
	return &derived
}
// generatePlanWithSCtx parses sql, points the session at defaultSchema with
// cost model version 2, collects the relevant optimizer variables and fixes,
// and then searches for plans over every ordered pair of distinct tables
// found in the statement.
func generatePlanWithSCtx(sctx sessionctx.Context, defaultSchema, sql, charset, collation string) (plans []*genedPlan, err error) {
	stmt, err := parser.New().ParseOneStmt(sql, charset, collation)
	if err != nil {
		return nil, err
	}
	sessVars := sctx.GetSessionVars()
	sessVars.CurrentDB = defaultSchema
	sessVars.CostModelVersion = 2 // cost factor only works on cost-model v2
	vars, fixes, err := RecordRelevantOptVarsAndFixes(sctx, stmt)
	if err != nil {
		return nil, err
	}
	tables := extractSelectTableNames(defaultSchema, stmt)
	// enumerate all possible leading-2 table pairs (ordered, no self-pairs)
	pairs := make([][2]*tableName, 0, 8)
	for i, first := range tables {
		for j, second := range tables {
			if i != j {
				pairs = append(pairs, [2]*tableName{first, second})
			}
		}
	}
	return breadthFirstPlanSearch(sctx, stmt, vars, fixes, pairs)
}
// breadthFirstPlanSearch explores optimizer states breadth-first and returns
// the distinct plans it finds, sorted by plan digest.
//
// The search starts from the state returned by getStartState. Each explored
// state is planned with genPlanUnderState, and its neighbours are the states
// that differ by one leading-2 pair, one adjusted variable, or one adjusted
// fix-control. The loop stops when 30 distinct plan digests have been
// collected, when 5000 states have been recorded as visited (a state is
// recorded when it is enqueued), or when the queue is empty. Any error from
// planning or from adjusting a variable/fix aborts the search.
func breadthFirstPlanSearch(sctx sessionctx.Context, stmt ast.StmtNode,
	vars []string, fixes []uint64, possibleLeading2 [][2]*tableName) (plans []*genedPlan, err error) {
	const (
		maxPlans        = 30
		maxExploreState = 5000
	)
	// init BFS structures
	visitedStates := make(map[string]struct{})  // map[encodedState]struct{}, all visited states
	visitedPlans := make(map[string]*genedPlan) // map[planDigest]plan, all visited plans
	stateList := list.New()                     // states in queue to explore

	// enqueue pushes s onto the queue unless an identically encoded state was
	// seen before; the state is encoded once per candidate.
	enqueue := func(s *state) {
		key := s.Encode()
		if _, seen := visitedStates[key]; seen {
			return
		}
		visitedStates[key] = struct{}{}
		stateList.PushBack(s)
	}

	// start state: no specified leading hint + default values of all variables and fix-controls
	startState, err := getStartState(vars, fixes)
	if err != nil {
		return nil, err
	}
	enqueue(startState)

	for len(visitedPlans) < maxPlans && len(visitedStates) < maxExploreState && stateList.Len() > 0 {
		currState := stateList.Remove(stateList.Front()).(*state)
		plan, err := genPlanUnderState(sctx, stmt, currState)
		if err != nil {
			return nil, err
		}
		visitedPlans[plan.planDigest] = plan

		// in each step, adjust one variable or fix or join-order
		for _, leading2 := range possibleLeading2 {
			enqueue(newStateWithLeading2(currState, leading2))
		}
		for i, varName := range vars {
			newVarVal, err := adjustVar(varName, currState.varValues[i])
			if err != nil {
				return nil, err
			}
			enqueue(newStateWithNewVar(currState, varName, newVarVal))
		}
		for i, fixID := range fixes {
			newFixVal, err := adjustFix(fixID, currState.fixValues[i])
			if err != nil {
				return nil, err
			}
			enqueue(newStateWithNewFix(currState, fixID, newFixVal))
		}
	}

	plans = make([]*genedPlan, 0, len(visitedPlans))
	for _, plan := range visitedPlans {
		plans = append(plans, plan)
	}
	sort.Slice(plans, func(i, j int) bool { // to make the result stable
		return plans[i].planDigest < plans[j].planDigest
	})
	return plans, nil
}
// genPlanUnderState returns a plan generated under the given state (vars and fix-controls).
//
// It writes every variable of the state into the session variables of sctx,
// replaces the session's OptimizerFixControl map with one built from the
// state's fixes, and, when both leading tables are set and stmt is a SELECT,
// appends a LEADING hint to the statement for the duration of the call (the
// deferred truncation removes it again). The session variables themselves are
// not restored by this function.
//
// An error is returned for a variable name this function does not know, for a
// fix-control string that fails to parse, or when GenBriefPlanWithSCtx fails.
// A variable value of an unexpected dynamic type panics on the type assertion.
func genPlanUnderState(sctx sessionctx.Context, stmt ast.StmtNode, state *state) (plan *genedPlan, err error) {
	sessVars := sctx.GetSessionVars()
	for i, varName := range state.varNames {
		val := state.varValues[i]
		switch varName {
		case vardef.TiDBOptIndexScanCostFactor:
			sessVars.IndexScanCostFactor = val.(float64)
		case vardef.TiDBOptIndexReaderCostFactor:
			sessVars.IndexReaderCostFactor = val.(float64)
		case vardef.TiDBOptTableReaderCostFactor:
			sessVars.TableReaderCostFactor = val.(float64)
		case vardef.TiDBOptTableFullScanCostFactor:
			sessVars.TableFullScanCostFactor = val.(float64)
		case vardef.TiDBOptTableRangeScanCostFactor:
			sessVars.TableRangeScanCostFactor = val.(float64)
		case vardef.TiDBOptTableRowIDScanCostFactor:
			sessVars.TableRowIDScanCostFactor = val.(float64)
		case vardef.TiDBOptTableTiFlashScanCostFactor:
			sessVars.TableTiFlashScanCostFactor = val.(float64)
		case vardef.TiDBOptIndexLookupCostFactor:
			sessVars.IndexLookupCostFactor = val.(float64)
		case vardef.TiDBOptIndexMergeCostFactor:
			sessVars.IndexMergeCostFactor = val.(float64)
		case vardef.TiDBOptSortCostFactor:
			sessVars.SortCostFactor = val.(float64)
		case vardef.TiDBOptTopNCostFactor:
			sessVars.TopNCostFactor = val.(float64)
		case vardef.TiDBOptLimitCostFactor:
			sessVars.LimitCostFactor = val.(float64)
		case vardef.TiDBOptStreamAggCostFactor:
			sessVars.StreamAggCostFactor = val.(float64)
		case vardef.TiDBOptHashAggCostFactor:
			sessVars.HashAggCostFactor = val.(float64)
		case vardef.TiDBOptMergeJoinCostFactor:
			sessVars.MergeJoinCostFactor = val.(float64)
		case vardef.TiDBOptHashJoinCostFactor:
			sessVars.HashJoinCostFactor = val.(float64)
		case vardef.TiDBOptIndexJoinCostFactor:
			sessVars.IndexJoinCostFactor = val.(float64)
		case vardef.TiDBOptOrderingIdxSelRatio:
			sessVars.OptOrderingIdxSelRatio = val.(float64)
		case vardef.TiDBOptRiskEqSkewRatio:
			sessVars.RiskEqSkewRatio = val.(float64)
		case vardef.TiDBOptRiskGroupNDVSkewRatio:
			sessVars.RiskGroupNDVSkewRatio = val.(float64)
		case vardef.TiDBOptRiskRangeSkewRatio:
			sessVars.RiskRangeSkewRatio = val.(float64)
		case vardef.TiDBOptPreferRangeScan:
			sessVars.SetAllowPreferRangeScan(val.(bool))
		case vardef.TiDBOptEnableNoDecorrelateInSelect:
			sessVars.EnableNoDecorrelateInSelect = val.(bool)
		case vardef.TiDBOptEnableSemiJoinRewrite:
			sessVars.EnableSemiJoinRewrite = val.(bool)
		case vardef.TiDBOptAlwaysKeepJoinKey:
			// getStartState seeds this variable and adjustVar flips it, so it
			// must be applicable here as well; without this case the whole
			// search failed with "unsupported variable".
			// NOTE(review): confirm the session field is named AlwaysKeepJoinKey.
			sessVars.AlwaysKeepJoinKey = val.(bool)
		case vardef.TiDBOptSelectivityFactor:
			sessVars.SelectivityFactor = val.(float64)
		default:
			// NOTE(review): getStartState also accepts
			// vardef.TiDBOptCartesianJoinOrderThreshold, which is handled
			// neither here nor in adjustVar; confirm whether it should be.
			return nil, fmt.Errorf("unsupported variable %s in plan generation", varName)
		}
	}

	// Serialise the fixes as "id:value,id:value" and parse them back into the
	// map form the session variables hold.
	fixControlStrBuilder := strings.Builder{}
	for i, fixID := range state.fixIDs {
		if i > 0 {
			fixControlStrBuilder.WriteString(",")
		}
		fmt.Fprintf(&fixControlStrBuilder, "%v:%v", fixID, state.fixValues[i])
	}
	fixControlMap, _, err := fixcontrol.ParseToMap(fixControlStrBuilder.String())
	if err != nil {
		return nil, err
	}
	sessVars.OptimizerFixControl = fixControlMap

	// construct the leading hint and add it into the current stmtNode
	if state.leading2[0] != nil && state.leading2[1] != nil {
		if sel, isSel := stmt.(*ast.SelectStmt); isSel {
			// Drop the appended hint again on return so the shared stmt is
			// left with its original hints for the next state.
			defer func(hintsLen int) {
				sel.TableHints = sel.TableHints[:hintsLen]
			}(len(sel.TableHints))
			leadingHint := &ast.TableOptimizerHint{
				HintName: ast.NewCIStr(hint.HintLeading),
				Tables: []ast.HintTable{
					{
						DBName:    ast.NewCIStr(state.leading2[0].schema),
						TableName: ast.NewCIStr(state.leading2[0].name),
					},
					{
						DBName:    ast.NewCIStr(state.leading2[1].schema),
						TableName: ast.NewCIStr(state.leading2[1].name),
					},
				},
			}
			sel.TableHints = append(sel.TableHints, leadingHint)
		}
	}

	planDigest, planHints, planText, err := GenBriefPlanWithSCtx(sctx, stmt)
	if err != nil {
		return nil, err
	}
	return &genedPlan{
		planDigest: planDigest,
		planText:   planText,
		planHints:  planHints,
	}, nil
}
// adjustVar returns the new value of the variable for plan generation.
//
//   - Cost factors are multiplied by 5, unless they are already >= 1e6.
//   - Ratio-like variables move to 0.1 when <= 0, otherwise grow by 0.1 as long
//     as the result does not exceed 1.
//   - Boolean switches are flipped.
//
// Any other variable name yields an error. The value must have the dynamic
// type the variable expects (float64 or bool), otherwise the assertion panics.
func adjustVar(varName string, varVal any) (newVarVal any, err error) {
	switch varName {
	case vardef.TiDBOptIndexScanCostFactor,
		vardef.TiDBOptIndexReaderCostFactor,
		vardef.TiDBOptTableReaderCostFactor,
		vardef.TiDBOptTableFullScanCostFactor,
		vardef.TiDBOptTableRangeScanCostFactor,
		vardef.TiDBOptTableRowIDScanCostFactor,
		vardef.TiDBOptTableTiFlashScanCostFactor,
		vardef.TiDBOptIndexLookupCostFactor,
		vardef.TiDBOptIndexMergeCostFactor,
		vardef.TiDBOptSortCostFactor,
		vardef.TiDBOptTopNCostFactor,
		vardef.TiDBOptLimitCostFactor,
		vardef.TiDBOptStreamAggCostFactor,
		vardef.TiDBOptHashAggCostFactor,
		vardef.TiDBOptMergeJoinCostFactor,
		vardef.TiDBOptHashJoinCostFactor,
		vardef.TiDBOptIndexJoinCostFactor:
		// for cost factors, we add some penalty (5 times its current cost) in each step.
		factor := varVal.(float64)
		if factor >= 1e6 { // avoid too large penalty.
			return factor, nil
		}
		return factor * 5, nil
	case vardef.TiDBOptOrderingIdxSelRatio,
		vardef.TiDBOptRiskEqSkewRatio,
		vardef.TiDBOptRiskRangeSkewRatio,
		vardef.TiDBOptRiskGroupNDVSkewRatio,
		vardef.TiDBOptSelectivityFactor:
		// range [0, 1], "<=0" means disable
		ratio := varVal.(float64)
		switch {
		case ratio <= 0:
			return 0.1, nil
		case ratio+0.1 > 1:
			return ratio, nil
		default:
			// increase 0.1 each step
			return ratio + 0.1, nil
		}
	case vardef.TiDBOptPreferRangeScan,
		vardef.TiDBOptEnableNoDecorrelateInSelect,
		vardef.TiDBOptAlwaysKeepJoinKey,
		vardef.TiDBOptEnableSemiJoinRewrite:
		// flip the switch
		enabled := varVal.(bool)
		return !enabled, nil
	default:
		return nil, fmt.Errorf("unsupported variable %s in plan generation", varName)
	}
}
// adjustFix returns the new value of the fix-control for plan generation.
//
// Fix44855 and Fix52869 are on/off switches: a value that equals vardef.Off
// after trimming and upper-casing becomes vardef.On, anything else becomes
// vardef.Off. Fix45132 is an integer that is halved on each step until it is
// <= 10, at which point it is returned unchanged. Other fix IDs, or a
// non-integer Fix45132 value, yield an error.
func adjustFix(fixID uint64, fixVal string) (newFixVal string, err error) {
	switch fixID {
	case fixcontrol.Fix44855, fixcontrol.Fix52869:
		// flip the switch
		if strings.ToUpper(strings.TrimSpace(fixVal)) == vardef.Off {
			return vardef.On, nil
		}
		return vardef.Off, nil
	case fixcontrol.Fix45132:
		threshold, parseErr := strconv.ParseInt(fixVal, 10, 64)
		if parseErr != nil {
			return "", parseErr
		}
		if threshold <= 10 {
			return fixVal, nil
		}
		// each time become 50% more aggressive.
		halved := threshold / 2
		return fmt.Sprintf("%v", halved), nil
	default:
		return "", fmt.Errorf("unsupported fix-control %d in plan generation", fixID)
	}
}
// getStartState builds the initial search state: no leading tables, and for
// every requested variable and fix-control its default value. The returned
// state shares the vars and fixes slices with the caller. An error is
// returned for a variable or fix-control that has no known default here.
func getStartState(vars []string, fixes []uint64) (*state, error) {
	// use the default values of these vars and fix-controls as the initial state.
	varDefaults := map[string]any{
		vardef.TiDBOptIndexScanCostFactor:         vardef.DefOptIndexScanCostFactor,
		vardef.TiDBOptIndexReaderCostFactor:       vardef.DefOptIndexReaderCostFactor,
		vardef.TiDBOptTableReaderCostFactor:       vardef.DefOptTableReaderCostFactor,
		vardef.TiDBOptTableFullScanCostFactor:     vardef.DefOptTableFullScanCostFactor,
		vardef.TiDBOptTableRangeScanCostFactor:    vardef.DefOptTableRangeScanCostFactor,
		vardef.TiDBOptTableRowIDScanCostFactor:    vardef.DefOptTableRowIDScanCostFactor,
		vardef.TiDBOptTableTiFlashScanCostFactor:  vardef.DefOptTableTiFlashScanCostFactor,
		vardef.TiDBOptIndexLookupCostFactor:       vardef.DefOptIndexLookupCostFactor,
		vardef.TiDBOptIndexMergeCostFactor:        vardef.DefOptIndexMergeCostFactor,
		vardef.TiDBOptSortCostFactor:              vardef.DefOptSortCostFactor,
		vardef.TiDBOptTopNCostFactor:              vardef.DefOptTopNCostFactor,
		vardef.TiDBOptLimitCostFactor:             vardef.DefOptLimitCostFactor,
		vardef.TiDBOptStreamAggCostFactor:         vardef.DefOptStreamAggCostFactor,
		vardef.TiDBOptHashAggCostFactor:           vardef.DefOptHashAggCostFactor,
		vardef.TiDBOptMergeJoinCostFactor:         vardef.DefOptMergeJoinCostFactor,
		vardef.TiDBOptHashJoinCostFactor:          vardef.DefOptHashJoinCostFactor,
		vardef.TiDBOptIndexJoinCostFactor:         vardef.DefOptIndexJoinCostFactor,
		vardef.TiDBOptOrderingIdxSelRatio:         vardef.DefTiDBOptOrderingIdxSelRatio,
		vardef.TiDBOptRiskEqSkewRatio:             vardef.DefOptRiskEqSkewRatio,
		vardef.TiDBOptRiskRangeSkewRatio:          vardef.DefOptRiskRangeSkewRatio,
		vardef.TiDBOptRiskGroupNDVSkewRatio:       vardef.DefOptRiskGroupNDVSkewRatio,
		vardef.TiDBOptPreferRangeScan:             vardef.DefOptPreferRangeScan,
		vardef.TiDBOptEnableNoDecorrelateInSelect: vardef.DefOptEnableNoDecorrelateInSelect,
		vardef.TiDBOptEnableSemiJoinRewrite:       vardef.DefOptEnableSemiJoinRewrite,
		vardef.TiDBOptAlwaysKeepJoinKey:           vardef.DefOptAlwaysKeepJoinKey,
		vardef.TiDBOptSelectivityFactor:           vardef.DefOptSelectivityFactor,
		vardef.TiDBOptCartesianJoinOrderThreshold: vardef.DefOptCartesianJoinOrderThreshold,
	}
	fixDefaults := map[uint64]string{
		fixcontrol.Fix44855: "OFF",
		fixcontrol.Fix45132: "1000",
		fixcontrol.Fix52869: "OFF",
	}

	start := &state{varNames: vars, fixIDs: fixes}
	for _, name := range vars {
		def, known := varDefaults[name]
		if !known {
			return nil, fmt.Errorf("unsupported variable %s in plan generation", name)
		}
		start.varValues = append(start.varValues, def)
	}
	for _, id := range fixes {
		def, known := fixDefaults[id]
		if !known {
			return nil, fmt.Errorf("unsupported fix-control %d in plan generation", id)
		}
		start.fixValues = append(start.fixValues, def)
	}
	return start, nil
}
// tableNameExtractor is an ast.Visitor that collects the distinct table names
// it encounters while walking a statement.
type tableNameExtractor struct {
defaultSchema string // schema given to table names that carry no explicit schema
tableNames map[string]*tableName // collected tables, keyed by tableName.String()
}
// Enter implements ast.Visitor interface.
// For every *ast.TableName node it records the lower-cased schema and table
// name, falling back to defaultSchema when the node has no schema. A table
// that was already recorded is not overwritten. Children are never skipped.
func (e *tableNameExtractor) Enter(in ast.Node) (node ast.Node, skipChildren bool) {
	name, isTable := in.(*ast.TableName)
	if !isTable {
		return in, false
	}
	schema := name.Schema.L
	if schema == "" {
		schema = e.defaultSchema
	}
	tbl := &tableName{schema: schema, name: name.Name.L}
	key := tbl.String()
	if _, seen := e.tableNames[key]; !seen {
		e.tableNames[key] = tbl
	}
	return in, false
}
// Leave implements ast.Visitor interface.
// It does no work: the node is returned unchanged together with ok=true.
func (*tableNameExtractor) Leave(in ast.Node) (node ast.Node, ok bool) {
return in, true
}
// extractSelectTableNames returns the table names in the SELECT statement.
// The result holds each distinct table once, ordered by its "schema.name"
// string; names without a schema are attributed to defaultSchema. For any
// statement that is not a *ast.SelectStmt it returns nil.
func extractSelectTableNames(defaultSchema string, node ast.StmtNode) []*tableName {
	selStmt, isSel := node.(*ast.SelectStmt)
	if !isSel {
		return nil // only support SELECT statement for now
	}
	collector := &tableNameExtractor{
		defaultSchema: defaultSchema,
		tableNames:    make(map[string]*tableName),
	}
	selStmt.Accept(collector)

	tables := make([]*tableName, 0, len(collector.tableNames))
	for _, tbl := range collector.tableNames {
		tables = append(tables, tbl)
	}
	// Map iteration order is random, so sort for a deterministic result.
	byQualifiedName := func(a, b int) bool {
		return tables[a].String() < tables[b].String()
	}
	sort.Slice(tables, byQualifiedName)
	return tables
}