// Copyright 2025 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package bindinfo import ( "container/list" "fmt" "sort" "strconv" "strings" "github.com/pingcap/tidb/pkg/parser" "github.com/pingcap/tidb/pkg/parser/ast" "github.com/pingcap/tidb/pkg/planner/util/fixcontrol" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/vardef" "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/hint" ) // PlanGenerator is used to generate new Plan Candidates for this specified query. type PlanGenerator interface { Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error) } // planGenerator implements PlanGenerator. // It generates new plans via adjusting the optimizer variables and fixes. type planGenerator struct { sPool util.DestroyableSessionPool } // Generate generates new plans for the given SQL statement. func (g *planGenerator) Generate(defaultSchema, sql, charset, collation string) (plans []*BindingPlanInfo, err error) { // TODO: only support SQL starting with SELECT for now, support other types of SQLs later. // TODO: make this check more strict. sql = strings.TrimSpace(sql) prefix := "SELECT" if len(sql) < len(prefix) || strings.ToUpper(sql[:len(prefix)]) != prefix { return nil, nil // not a SELECT statement } err = callWithSCtx(g.sPool, false, func(sctx sessionctx.Context) error { genedPlans, err := generatePlanWithSCtx(sctx, defaultSchema, sql, charset, collation) if err != nil { return err } plans = make([]*BindingPlanInfo, 0, len(genedPlans)) for _, genedPlan := range genedPlans { // TODO: construct bindingSQL in a more strict way. bindingSQL := sql[:len(prefix)] + " /*+ " + genedPlan.planHints + " */ " + sql[len(prefix):] binding := &Binding{ OriginalSQL: sql, BindSQL: bindingSQL, Db: defaultSchema, Source: "generated", PlanDigest: genedPlan.planDigest, } if err := prepareHints(sctx, binding); err != nil { return err } plan := &BindingPlanInfo{ Binding: binding, Plan: genedPlan.PlanText(), } plans = append(plans, plan) } return nil }) return } type tableName struct { schema string name string } func (t *tableName) String() string { return fmt.Sprintf("%s.%s", t.schema, t.name) } // genedPlan represents a plan generated by planGenerator. type genedPlan struct { planDigest string // digest of this plan planHints string // a set of hints to reproduce this plan planText [][]string // human-readable plan text } func (gp *genedPlan) PlanText() string { sb := new(strings.Builder) for i, row := range gp.planText { if i > 0 { sb.WriteString("\n") } for j, col := range row { if j > 0 { sb.WriteString("\t") } sb.WriteString(col) } } return sb.String() } // state represents a state of the optimizer variables and fixes. type state struct { leading2 [2]*tableName // leading-2 table names varNames []string // relevant variables and their values to generate a certain plan varValues []any fixIDs []uint64 // relevant fixes and their values to generate a certain plan fixValues []string } // Encode encodes the state into a string. func (s *state) Encode() string { sb := new(strings.Builder) for _, t := range s.leading2 { if t == nil { continue } if sb.Len() > 0 { sb.WriteString(",") } sb.WriteString(t.String()) } for _, v := range s.varValues { if sb.Len() > 0 { sb.WriteString(",") } if _, isFloat := v.(float64); isFloat { // only consider 4 decimal digits, which should be enough for optimizer tuning. fmt.Fprintf(sb, "%.4f", v) continue } fmt.Fprintf(sb, "%v", v) } for _, v := range s.fixValues { if sb.Len() > 0 { sb.WriteString(",") } sb.WriteString(v) } return sb.String() } func newStateWithLeading2(old *state, leading2 [2]*tableName) *state { newState := &state{ leading2: leading2, varNames: old.varNames, varValues: old.varValues, fixIDs: old.fixIDs, fixValues: old.fixValues, } return newState } func newStateWithNewVar(old *state, varName string, varVal any) *state { newState := &state{ leading2: old.leading2, varNames: old.varNames, varValues: make([]any, len(old.varValues)), fixIDs: old.fixIDs, fixValues: old.fixValues, } copy(newState.varValues, old.varValues) for i := range newState.varNames { if newState.varNames[i] == varName { newState.varValues[i] = varVal break } } return newState } func newStateWithNewFix(old *state, fixID uint64, fixVal string) *state { newState := &state{ leading2: old.leading2, varNames: old.varNames, varValues: old.varValues, fixIDs: old.fixIDs, fixValues: make([]string, len(old.fixValues)), } copy(newState.fixValues, old.fixValues) for i := range newState.fixIDs { if newState.fixIDs[i] == fixID { newState.fixValues[i] = fixVal break } } return newState } func generatePlanWithSCtx(sctx sessionctx.Context, defaultSchema, sql, charset, collation string) (plans []*genedPlan, err error) { p := parser.New() stmt, err := p.ParseOneStmt(sql, charset, collation) if err != nil { return nil, err } sctx.GetSessionVars().CurrentDB = defaultSchema sctx.GetSessionVars().CostModelVersion = 2 // cost factor only works on cost-model v2 vars, fixes, err := RecordRelevantOptVarsAndFixes(sctx, stmt) if err != nil { return nil, err } tableNames := extractSelectTableNames(defaultSchema, stmt) possibleLeading2 := make([][2]*tableName, 0, 8) // enumerate all possible leading-2 table pairs for i := range tableNames { for j := range tableNames { if i == j { continue } possibleLeading2 = append(possibleLeading2, [2]*tableName{tableNames[i], tableNames[j]}) } } return breadthFirstPlanSearch(sctx, stmt, vars, fixes, possibleLeading2) } func breadthFirstPlanSearch(sctx sessionctx.Context, stmt ast.StmtNode, vars []string, fixes []uint64, possibleLeading2 [][2]*tableName) (plans []*genedPlan, err error) { // init BFS structures visitedStates := make(map[string]struct{}) // map[encodedState]struct{}, all visited states visitedPlans := make(map[string]*genedPlan) // map[planDigest]plan, all visited plans stateList := list.New() // states in queue to explore // init the start state and push it into the BFS list // start state: no specified leading hint + default values of all variables and fix-controls startState, err := getStartState(vars, fixes) if err != nil { return nil, err } visitedStates[startState.Encode()] = struct{}{} stateList.PushBack(startState) maxPlans, maxExploreState := 30, 5000 for len(visitedPlans) < maxPlans && len(visitedStates) < maxExploreState && stateList.Len() > 0 { currState := stateList.Remove(stateList.Front()).(*state) plan, err := genPlanUnderState(sctx, stmt, currState) if err != nil { return nil, err } visitedPlans[plan.planDigest] = plan // in each step, adjust one variable or fix or join-order for _, leading2 := range possibleLeading2 { newState := newStateWithLeading2(currState, leading2) if _, ok := visitedStates[newState.Encode()]; !ok { visitedStates[newState.Encode()] = struct{}{} stateList.PushBack(newState) } } for i := range vars { varName, varVal := vars[i], currState.varValues[i] newVarVal, err := adjustVar(varName, varVal) if err != nil { return nil, err } newState := newStateWithNewVar(currState, varName, newVarVal) if _, ok := visitedStates[newState.Encode()]; !ok { visitedStates[newState.Encode()] = struct{}{} stateList.PushBack(newState) } } for i := range fixes { fixID, fixVal := fixes[i], currState.fixValues[i] newFixVal, err := adjustFix(fixID, fixVal) if err != nil { return nil, err } newState := newStateWithNewFix(currState, fixID, newFixVal) if _, ok := visitedStates[newState.Encode()]; !ok { visitedStates[newState.Encode()] = struct{}{} stateList.PushBack(newState) } } } plans = make([]*genedPlan, 0, len(visitedPlans)) for _, plan := range visitedPlans { plans = append(plans, plan) } sort.Slice(plans, func(i, j int) bool { // to make the result stable return plans[i].planDigest < plans[j].planDigest }) return plans, nil } // genPlanUnderState returns a plan generated under the given state (vars and fix-controls). func genPlanUnderState(sctx sessionctx.Context, stmt ast.StmtNode, state *state) (plan *genedPlan, err error) { for i, varName := range state.varNames { switch varName { case vardef.TiDBOptIndexScanCostFactor: sctx.GetSessionVars().IndexScanCostFactor = state.varValues[i].(float64) case vardef.TiDBOptIndexReaderCostFactor: sctx.GetSessionVars().IndexReaderCostFactor = state.varValues[i].(float64) case vardef.TiDBOptTableReaderCostFactor: sctx.GetSessionVars().TableReaderCostFactor = state.varValues[i].(float64) case vardef.TiDBOptTableFullScanCostFactor: sctx.GetSessionVars().TableFullScanCostFactor = state.varValues[i].(float64) case vardef.TiDBOptTableRangeScanCostFactor: sctx.GetSessionVars().TableRangeScanCostFactor = state.varValues[i].(float64) case vardef.TiDBOptTableRowIDScanCostFactor: sctx.GetSessionVars().TableRowIDScanCostFactor = state.varValues[i].(float64) case vardef.TiDBOptTableTiFlashScanCostFactor: sctx.GetSessionVars().TableTiFlashScanCostFactor = state.varValues[i].(float64) case vardef.TiDBOptIndexLookupCostFactor: sctx.GetSessionVars().IndexLookupCostFactor = state.varValues[i].(float64) case vardef.TiDBOptIndexMergeCostFactor: sctx.GetSessionVars().IndexMergeCostFactor = state.varValues[i].(float64) case vardef.TiDBOptSortCostFactor: sctx.GetSessionVars().SortCostFactor = state.varValues[i].(float64) case vardef.TiDBOptTopNCostFactor: sctx.GetSessionVars().TopNCostFactor = state.varValues[i].(float64) case vardef.TiDBOptLimitCostFactor: sctx.GetSessionVars().LimitCostFactor = state.varValues[i].(float64) case vardef.TiDBOptStreamAggCostFactor: sctx.GetSessionVars().StreamAggCostFactor = state.varValues[i].(float64) case vardef.TiDBOptHashAggCostFactor: sctx.GetSessionVars().HashAggCostFactor = state.varValues[i].(float64) case vardef.TiDBOptMergeJoinCostFactor: sctx.GetSessionVars().MergeJoinCostFactor = state.varValues[i].(float64) case vardef.TiDBOptHashJoinCostFactor: sctx.GetSessionVars().HashJoinCostFactor = state.varValues[i].(float64) case vardef.TiDBOptIndexJoinCostFactor: sctx.GetSessionVars().IndexJoinCostFactor = state.varValues[i].(float64) case vardef.TiDBOptOrderingIdxSelRatio: sctx.GetSessionVars().OptOrderingIdxSelRatio = state.varValues[i].(float64) case vardef.TiDBOptRiskEqSkewRatio: sctx.GetSessionVars().RiskEqSkewRatio = state.varValues[i].(float64) case vardef.TiDBOptRiskGroupNDVSkewRatio: sctx.GetSessionVars().RiskGroupNDVSkewRatio = state.varValues[i].(float64) case vardef.TiDBOptRiskRangeSkewRatio: sctx.GetSessionVars().RiskRangeSkewRatio = state.varValues[i].(float64) case vardef.TiDBOptPreferRangeScan: sctx.GetSessionVars().SetAllowPreferRangeScan(state.varValues[i].(bool)) case vardef.TiDBOptEnableNoDecorrelateInSelect: sctx.GetSessionVars().EnableNoDecorrelateInSelect = state.varValues[i].(bool) case vardef.TiDBOptSelectivityFactor: sctx.GetSessionVars().SelectivityFactor = state.varValues[i].(float64) default: return nil, fmt.Errorf("unsupported variable %s in plan generation", varName) } } fixControlStrBuilder := strings.Builder{} for i, fixID := range state.fixIDs { if i > 0 { fixControlStrBuilder.WriteString(",") } fixControlStrBuilder.WriteString(fmt.Sprintf("%v:%v", fixID, state.fixValues[i])) } fixControlMap, _, err := fixcontrol.ParseToMap(fixControlStrBuilder.String()) if err != nil { return nil, err } sctx.GetSessionVars().OptimizerFixControl = fixControlMap // construct the leading hint and add it into the current stmtNode if state.leading2[0] != nil && state.leading2[1] != nil { if sel, isSel := stmt.(*ast.SelectStmt); isSel { defer func(hintsLen int) { sel.TableHints = sel.TableHints[:hintsLen] }(len(sel.TableHints)) leadingHint := &ast.TableOptimizerHint{ HintName: ast.NewCIStr(hint.HintLeading), Tables: []ast.HintTable{ { DBName: ast.NewCIStr(state.leading2[0].schema), TableName: ast.NewCIStr(state.leading2[0].name), }, { DBName: ast.NewCIStr(state.leading2[1].schema), TableName: ast.NewCIStr(state.leading2[1].name), }, }, } sel.TableHints = append(sel.TableHints, leadingHint) } } planDigest, planHints, planText, err := GenBriefPlanWithSCtx(sctx, stmt) if err != nil { return nil, err } return &genedPlan{ planDigest: planDigest, planText: planText, planHints: planHints, }, nil } // adjustVar returns the new value of the variable for plan generation. func adjustVar(varName string, varVal any) (newVarVal any, err error) { switch varName { case vardef.TiDBOptIndexScanCostFactor, vardef.TiDBOptIndexReaderCostFactor, vardef.TiDBOptTableReaderCostFactor, vardef.TiDBOptTableFullScanCostFactor, vardef.TiDBOptTableRangeScanCostFactor, vardef.TiDBOptTableRowIDScanCostFactor, vardef.TiDBOptTableTiFlashScanCostFactor, vardef.TiDBOptIndexLookupCostFactor, vardef.TiDBOptIndexMergeCostFactor, vardef.TiDBOptSortCostFactor, vardef.TiDBOptTopNCostFactor, vardef.TiDBOptLimitCostFactor, vardef.TiDBOptStreamAggCostFactor, vardef.TiDBOptHashAggCostFactor, vardef.TiDBOptMergeJoinCostFactor, vardef.TiDBOptHashJoinCostFactor, vardef.TiDBOptIndexJoinCostFactor: // for cost factors, we add add some penalties (5 tims of its current cost) in each step. v := varVal.(float64) if v >= 1e6 { // avoid too large penalty. return v, nil } return v * 5, nil case vardef.TiDBOptOrderingIdxSelRatio, vardef.TiDBOptRiskEqSkewRatio, vardef.TiDBOptRiskRangeSkewRatio, vardef.TiDBOptRiskGroupNDVSkewRatio, vardef.TiDBOptSelectivityFactor: // range [0, 1], "<=0" means disable v := varVal.(float64) if v <= 0 { return 0.1, nil } else if v+0.1 > 1 { return v, nil } // increase 0.1 each step return v + 0.1, nil case vardef.TiDBOptPreferRangeScan, vardef.TiDBOptEnableNoDecorrelateInSelect: // flip the switch return !varVal.(bool), nil } return nil, fmt.Errorf("unsupported variable %s in plan generation", varName) } // adjustFix returns the new value of the fix-control for plan generation. func adjustFix(fixID uint64, fixVal string) (newFixVal string, err error) { switch fixID { case fixcontrol.Fix44855, fixcontrol.Fix52869: // flip the switch fixVal = strings.ToUpper(strings.TrimSpace(fixVal)) if fixVal == vardef.Off { return vardef.On, nil } return vardef.Off, nil case fixcontrol.Fix45132: num, err := strconv.ParseInt(fixVal, 10, 64) if err != nil { return "", err } if num <= 10 { return fixVal, nil } // each time become 50% more aggressive. return fmt.Sprintf("%v", num/2), nil default: return "", fmt.Errorf("unsupported fix-control %d in plan generation", fixID) } } func getStartState(vars []string, fixes []uint64) (*state, error) { // use the default values of these vars and fix-controls as the initial state. s := &state{varNames: vars, fixIDs: fixes} for _, varName := range vars { switch varName { case vardef.TiDBOptIndexScanCostFactor: s.varValues = append(s.varValues, vardef.DefOptIndexScanCostFactor) case vardef.TiDBOptIndexReaderCostFactor: s.varValues = append(s.varValues, vardef.DefOptIndexReaderCostFactor) case vardef.TiDBOptTableReaderCostFactor: s.varValues = append(s.varValues, vardef.DefOptTableReaderCostFactor) case vardef.TiDBOptTableFullScanCostFactor: s.varValues = append(s.varValues, vardef.DefOptTableFullScanCostFactor) case vardef.TiDBOptTableRangeScanCostFactor: s.varValues = append(s.varValues, vardef.DefOptTableRangeScanCostFactor) case vardef.TiDBOptTableRowIDScanCostFactor: s.varValues = append(s.varValues, vardef.DefOptTableRowIDScanCostFactor) case vardef.TiDBOptTableTiFlashScanCostFactor: s.varValues = append(s.varValues, vardef.DefOptTableTiFlashScanCostFactor) case vardef.TiDBOptIndexLookupCostFactor: s.varValues = append(s.varValues, vardef.DefOptIndexLookupCostFactor) case vardef.TiDBOptIndexMergeCostFactor: s.varValues = append(s.varValues, vardef.DefOptIndexMergeCostFactor) case vardef.TiDBOptSortCostFactor: s.varValues = append(s.varValues, vardef.DefOptSortCostFactor) case vardef.TiDBOptTopNCostFactor: s.varValues = append(s.varValues, vardef.DefOptTopNCostFactor) case vardef.TiDBOptLimitCostFactor: s.varValues = append(s.varValues, vardef.DefOptLimitCostFactor) case vardef.TiDBOptStreamAggCostFactor: s.varValues = append(s.varValues, vardef.DefOptStreamAggCostFactor) case vardef.TiDBOptHashAggCostFactor: s.varValues = append(s.varValues, vardef.DefOptHashAggCostFactor) case vardef.TiDBOptMergeJoinCostFactor: s.varValues = append(s.varValues, vardef.DefOptMergeJoinCostFactor) case vardef.TiDBOptHashJoinCostFactor: s.varValues = append(s.varValues, vardef.DefOptHashJoinCostFactor) case vardef.TiDBOptIndexJoinCostFactor: s.varValues = append(s.varValues, vardef.DefOptIndexJoinCostFactor) case vardef.TiDBOptOrderingIdxSelRatio: s.varValues = append(s.varValues, vardef.DefTiDBOptOrderingIdxSelRatio) case vardef.TiDBOptRiskEqSkewRatio: s.varValues = append(s.varValues, vardef.DefOptRiskEqSkewRatio) case vardef.TiDBOptRiskRangeSkewRatio: s.varValues = append(s.varValues, vardef.DefOptRiskRangeSkewRatio) case vardef.TiDBOptRiskGroupNDVSkewRatio: s.varValues = append(s.varValues, vardef.DefOptRiskGroupNDVSkewRatio) case vardef.TiDBOptPreferRangeScan: s.varValues = append(s.varValues, vardef.DefOptPreferRangeScan) case vardef.TiDBOptEnableNoDecorrelateInSelect: s.varValues = append(s.varValues, vardef.DefOptEnableNoDecorrelateInSelect) case vardef.TiDBOptSelectivityFactor: s.varValues = append(s.varValues, vardef.DefOptSelectivityFactor) default: return nil, fmt.Errorf("unsupported variable %s in plan generation", varName) } } for _, fixID := range fixes { switch fixID { case fixcontrol.Fix44855: s.fixValues = append(s.fixValues, "OFF") case fixcontrol.Fix45132: s.fixValues = append(s.fixValues, "1000") case fixcontrol.Fix52869: s.fixValues = append(s.fixValues, "OFF") default: return nil, fmt.Errorf("unsupported fix-control %d in plan generation", fixID) } } return s, nil } type tableNameExtractor struct { defaultSchema string tableNames map[string]*tableName } // Enter implements ast.Visitor interface. func (e *tableNameExtractor) Enter(in ast.Node) (node ast.Node, skipChildren bool) { if name, ok := in.(*ast.TableName); ok { t := &tableName{ schema: name.Schema.L, name: name.Name.L, } if t.schema == "" { t.schema = e.defaultSchema } if _, ok := e.tableNames[t.String()]; !ok { e.tableNames[t.String()] = t } } return in, false } // Leave implements ast.Visitor interface. func (*tableNameExtractor) Leave(in ast.Node) (node ast.Node, ok bool) { return in, true } // extractSelectTableNames returns the table names in the SELECT statement. func extractSelectTableNames(defaultSchema string, node ast.StmtNode) []*tableName { selStmt, isSel := node.(*ast.SelectStmt) if !isSel { return nil // only support SELECT statement for now } extractor := &tableNameExtractor{ defaultSchema: defaultSchema, tableNames: make(map[string]*tableName), } selStmt.Accept(extractor) names := make([]*tableName, 0, len(extractor.tableNames)) for _, name := range extractor.tableNames { names = append(names, name) } sort.Slice(names, func(i, j int) bool { return names[i].String() < names[j].String() }) return names }