// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package workloadlearning implements the Workload-Based Learning Optimizer.
// The Workload-Based Learning Optimizer introduces a new module in TiDB that leverages captured workload history to
// enhance the database query optimizer.
// By learning from historical data, this module helps the optimizer make smarter decisions, such as identifying hot
// and cold tables, analyzing resource consumption, etc.
// The workload analysis results can be used to directly suggest a better path,
// or to indirectly influence the cost model and stats so that the optimizer can select the best plan more intelligently and adaptively.
package workloadlearning
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/pingcap/tidb/pkg/infoschema"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessiontxn"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/plancodec"
"github.com/pingcap/tidb/pkg/util/sqlescape"
"github.com/pingcap/tidb/pkg/util/sqlexec"
"github.com/pingcap/tipb/go-tipb"
"go.uber.org/zap"
)
const batchInsertSize = 1000
// The default memory usage charged to a point get or batch point get.
// The unit is bytes.
const defaultPointGetMemUsage = 1
const (
// The category of workload-based learning
feedbackCategory = "Feedback"
)
const (
// The type of workload-based learning
tableReadCost = "TableReadCost"
)
const (
// The workload repo related constants
// Copied from pkg/util/workloadrepo/const.go.
// Importing const.go directly would cause an import cycle (domain -> workloadlearning -> workloadrepo -> domain).
// The workloadrepo dependency should become (domain -> workloadrepo), but for now these constants are copied to avoid the cycle.
defSnapshotInterval = 3600
histSnapshotsTable = "HIST_SNAPSHOTS"
)
const (
workloadQueryTemplate = "select summary_end.DIGEST, summary_end.DIGEST_TEXT, summary_end.BINARY_PLAN, " +
"if(summary_start.EXEC_COUNT is NULL, summary_end.EXEC_COUNT, summary_end.EXEC_COUNT - summary_start.EXEC_COUNT) as EXEC_COUNT FROM " +
"(SELECT DIGEST, DIGEST_TEXT, BINARY_PLAN, EXEC_COUNT " +
"FROM %?.HIST_TIDB_STATEMENTS_STATS WHERE STMT_TYPE = 'Select' AND SNAP_ID = %?) summary_end " +
"left join (SELECT DIGEST, DIGEST_TEXT, BINARY_PLAN, EXEC_COUNT " +
"FROM %?.HIST_TIDB_STATEMENTS_STATS WHERE STMT_TYPE = 'Select' AND SNAP_ID = %?) summary_start " +
"on summary_end.DIGEST = summary_start.DIGEST"
)
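// For illustration, with hypothetical snapshot IDs 41 (start) and 42 (end), the template above
// renders to a query that left-joins the two snapshots of HIST_TIDB_STATEMENTS_STATS on DIGEST,
// so EXEC_COUNT becomes the per-digest execution-count delta between the two snapshots (or the
// raw end-snapshot count for digests absent from the start snapshot).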
// Handle is the entry point for all workload-based learning related tasks.
type Handle struct {
sysSessionPool util.DestroyableSessionPool
}
// NewWorkloadLearningHandle creates a new workload-learning Handle.
// The Handle follows the singleton pattern.
func NewWorkloadLearningHandle(pool util.DestroyableSessionPool) *Handle {
return &Handle{pool}
}
// HandleTableReadCost starts a new round of analysis over all historical table read queries.
// Based on the extracted table cost metrics, it calculates each table's share of the read scan time and memory usage.
// The result is saved to the table "mysql.tidb_workload_values".
// Dataflow:
//  1. Extract intermediate table cost metrics (scan time, memory usage, read frequency)
//     from every record in statement_summary/statement_stats.
//  2. Group by table name to get each table's scan time, memory usage, and read frequency.
//  3. Compute the total scan time and total memory usage across all tables.
//  4. Calculate the table cost for each table:
//     table cost = table scan time / total scan time + table mem usage / total mem usage.
//  5. Save all per-table cost metrics (scan time, table cost, etc.) to the table "mysql.tidb_workload_values".
func (handle *Handle) HandleTableReadCost(infoSchema infoschema.InfoSchema) {
// steps 1-4: extract and aggregate the table cost metrics from the statement stats records
tableIDToMetrics, startTime, endTime := handle.analyzeBasedOnStatementStats(infoSchema)
if tableIDToMetrics == nil {
logutil.BgLogger().Warn("Failed to analyze table cost metrics from statement stats, skip saving table read cost metrics")
return
}
// step5: save the table cost metrics to table "mysql.tidb_workload_values"
handle.SaveTableReadCostMetrics(tableIDToMetrics, startTime, endTime)
}
// analyzeBasedOnStatementStats analyzes table cost metrics based on the statements stats history (HIST_TIDB_STATEMENTS_STATS) of the last 7 days.
func (handle *Handle) analyzeBasedOnStatementStats(infoSchema infoschema.InfoSchema) (tableIDToMetrics map[int64]*TableReadCostMetrics, startTime time.Time, endTime time.Time) {
// Calculate time range for the last 7 days
tableIDToMetrics = nil
endTime = time.Now()
startTime = endTime.AddDate(0, 0, -7)
// step1: create a new session, context for getting the records from ClusterTableTiDBStatementsStats/snapshot
se, err := handle.sysSessionPool.Get()
if err != nil {
logutil.BgLogger().Warn("get system session failed when fetch data from statements_stats", zap.Error(err))
return nil, startTime, endTime
}
defer func() {
if err == nil { // only recycle when no error
handle.sysSessionPool.Put(se)
} else {
// Note: Otherwise, the session will be leaked.
handle.sysSessionPool.Destroy(se)
}
}()
sctx := se.(sessionctx.Context)
exec := sctx.GetRestrictedSQLExecutor()
ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnWorkloadLearning)
// Step2: extract metrics from ClusterTableTiDBStatementsStats
// TODO: use constant for table name and field name/values
// TODO: control the cpu, memory, concurrency and timeout for the query
startSnapshotID, err := findClosestSnapshotIDByTime(ctx, exec, startTime)
if err != nil {
logutil.ErrVerboseLogger().Warn("Failed to query HIST_SNAPSHOTS table to find start snapshot id",
zap.Time("start_time", startTime),
zap.Error(err))
return nil, startTime, endTime
}
endSnapshotID, err := findClosestSnapshotIDByTime(ctx, exec, endTime)
if err != nil {
logutil.ErrVerboseLogger().Warn("Failed to query HIST_SNAPSHOTS table to find end snapshot id",
zap.Time("end_time", endTime),
zap.Error(err))
return nil, startTime, endTime
}
sql, err := sqlescape.EscapeSQL(workloadQueryTemplate, mysql.WorkloadSchema, endSnapshotID, mysql.WorkloadSchema, startSnapshotID)
if err != nil {
logutil.ErrVerboseLogger().Warn("Failed to construct SQL query for ClusterTableTiDBStatementsStats",
zap.Uint64("end_snapshot_id", endSnapshotID),
zap.Uint64("start_snapshot_id", startSnapshotID),
zap.Error(err))
return nil, startTime, endTime
}
// Step2.1: get statements stats record from ClusterTableTiDBStatementsStats
rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql)
if err != nil {
logutil.ErrVerboseLogger().Warn("Failed to query CLUSTER_TIDB_STATEMENTS_STATS table",
zap.Time("end_time", endTime),
zap.Error(err))
return nil, startTime, endTime
}
tableIDToMetrics = make(map[int64]*TableReadCostMetrics)
for _, row := range rows {
if row.IsNull(0) || row.IsNull(1) || row.IsNull(2) || row.IsNull(3) {
logutil.ErrVerboseLogger().Debug("Skip this row because there is null in the row")
continue
}
digest := row.GetString(0)
sql := row.GetString(1)
binaryPlan := row.GetString(2)
readFrequency := row.GetInt64(3)
// Step2.2: extract scan-time, memory-usage from planString
currentRecordMetrics, err := extractScanAndMemoryFromBinaryPlan(binaryPlan)
if err != nil {
logutil.ErrVerboseLogger().Warn("failed to abstract scan and memory from binary plan",
zap.String("digest", digest),
zap.String("sql", sql),
zap.Error(err))
continue
}
// Step2.3: accumulate all metrics from each record
AccumulateMetricsGroupByTableID(ctx, currentRecordMetrics, readFrequency, tableIDToMetrics, infoSchema)
}
// Step3: compute the table cost metrics by table scan time / total scan time + table mem usage / total mem usage
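// For example (hypothetical numbers): with two tables where table A scanned 2s out of a 10s
// total and used 1GiB out of a 4GiB total, A's TableReadCost = 2/10 + 1/4 = 0.45.
// Each addend is in [0, 1], so the cost ranges over [0, 2].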
// Step3.1: compute total scan time and memory usage
// TODO: handle potential overflow of the sums
var totalScanTime, totalMemUsage int64
for _, metric := range tableIDToMetrics {
totalMemUsage += metric.TableMemUsage
totalScanTime = totalScanTime + metric.TableScanTime.Nanoseconds()
}
// Step3.2: compute table cost metrics
for _, metric := range tableIDToMetrics {
var scanTimePercentage, memoryUsagePercentage float64
if totalScanTime != 0 {
scanTimePercentage = float64(metric.TableScanTime.Nanoseconds()) / float64(totalScanTime)
}
if totalMemUsage != 0 {
memoryUsagePercentage = float64(metric.TableMemUsage) / float64(totalMemUsage)
}
metric.TableReadCost = scanTimePercentage + memoryUsagePercentage
}
return tableIDToMetrics, startTime, endTime
}
// findClosestSnapshotIDByTime finds the closest snapshot ID in the HIST_SNAPSHOTS table
// That is, the snapshot is a **past** snapshot whose time is before the given time and **closest** to it.
// Only the **one closest** snapshot ID is found, so the result is a single uint64 value.
func findClosestSnapshotIDByTime(ctx context.Context, exec sqlexec.RestrictedSQLExecutor, targetTime time.Time) (uint64, error) {
sql := new(strings.Builder)
// Due to the time cost of the snapshot worker, to ensure that a snapshot ID can be found,
// the search window is extended to two snapshot intervals into the past.
sqlescape.MustFormatSQL(sql, "select SNAPSHOT_ID from %?.%? where BEGIN_TIME > %? and BEGIN_TIME <= %? and END_TIME is not null "+
"order by BEGIN_TIME desc limit 1",
mysql.WorkloadSchema, histSnapshotsTable,
targetTime.Add(-2*defSnapshotInterval*time.Second), targetTime)
rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql.String())
if err != nil {
return 0, err
}
if len(rows) == 0 {
return 0, fmt.Errorf("no closest snapshot found by time in hist_snapshot table, time: %v", targetTime)
}
return rows[0].GetUint64(0), nil
}
// SaveTableReadCostMetrics saves the per-table cost metrics to "mysql.tidb_workload_values",
// versioned by the start TS of the saving txn.
// The workload start and end time parameters are currently unused; they are reserved for the workload jobs table.
func (handle *Handle) SaveTableReadCostMetrics(metrics map[int64]*TableReadCostMetrics,
_, _ time.Time) {
// TODO save the workload job info such as start end time into workload_jobs table
// step1: create a new session, context, txn for saving table cost metrics
se, err := handle.sysSessionPool.Get()
if err != nil {
logutil.BgLogger().Warn("get system session failed when saving table cost metrics", zap.Error(err))
return
}
defer func() {
if err == nil { // only recycle when no error
handle.sysSessionPool.Put(se)
} else {
// Note: Otherwise, the session will be leaked.
handle.sysSessionPool.Destroy(se)
}
}()
sctx := se.(sessionctx.Context)
exec := sctx.GetRestrictedSQLExecutor()
ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnWorkloadLearning)
// begin a new txn
err = sessiontxn.NewTxn(context.Background(), sctx)
if err != nil {
logutil.BgLogger().Warn("get txn failed when saving table cost metrics", zap.Error(err))
return
}
txn, err := sctx.Txn(true)
if err != nil {
logutil.BgLogger().Warn("failed to get txn when saving table cost metrics", zap.Error(err))
return
}
// enable the non-prepared plan cache for this internal session
sctx.GetSessionVars().EnableNonPreparedPlanCache = true
// step2: insert the new version of table cost metrics in batches, sharing one common txn and context
version := txn.StartTS()
// build the insert sql in batches of batchInsertSize tables
const insertPrefix = "insert into mysql.tidb_workload_values (version, category, type, table_id, value) values "
i := 0
sql := new(strings.Builder)
sql.WriteString(insertPrefix)
for tableID, metric := range metrics {
metricBytes, marshalErr := json.Marshal(metric)
if marshalErr != nil {
logutil.BgLogger().Warn("marshal table cost metrics failed",
zap.String("db_name", metric.DbName.String()),
zap.String("table_name", metric.TableName.String()),
zap.Duration("table_scan_time", metric.TableScanTime),
zap.Int64("table_mem_usage", metric.TableMemUsage),
zap.Int64("read_frequency", metric.ReadFrequency),
zap.Float64("table_read_cost", metric.TableReadCost),
zap.Error(marshalErr))
continue
}
sqlescape.MustFormatSQL(sql, "(%?, %?, %?, %?, %?)",
version, feedbackCategory, tableReadCost, tableID, json.RawMessage(metricBytes))
// TODO: check the txn record limit
if i%batchInsertSize == batchInsertSize-1 {
// assign the outer err so the deferred cleanup destroys the session on failure
if _, _, err = exec.ExecRestrictedSQL(ctx, nil, sql.String()); err != nil {
logutil.BgLogger().Warn("insert new version table cost metrics failed", zap.Error(err))
return
}
sql.Reset()
sql.WriteString(insertPrefix)
} else {
sql.WriteString(", ")
}
i++
}
// insert the last batch; skip when the builder only holds the prefix, which happens
// when the metrics map is empty or the previous batch was flushed exactly on the boundary
if sql.Len() > len(insertPrefix) {
// remove the trailing ", "
lastSQL := strings.TrimSuffix(sql.String(), ", ")
if _, _, err = exec.ExecRestrictedSQL(ctx, nil, lastSQL); err != nil {
logutil.BgLogger().Warn("insert new version table cost metrics failed", zap.Error(err))
return
}
}
// step3: commit the txn, finish the save
err = txn.Commit(context.Background())
if err != nil {
logutil.BgLogger().Warn("commit txn failed when saving table cost metrics", zap.Error(err))
}
}
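// For illustration, a row persisted by SaveTableReadCostMetrics looks roughly like
// (hypothetical values):
//
//	version=<txn start ts>, category="Feedback", type="TableReadCost", table_id=105,
//	value=<the JSON-marshaled TableReadCostMetrics>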
// extractScanAndMemoryFromBinaryPlan extracts the scan time and memory usage from one binary plan string.
// The input is the BINARY_PLAN field of **one row** of the table mysql.CLUSTER_TIDB_STATEMENTS_STATS.
// The result is a **list** because one query may touch multiple tables, such as: select * from t1, t2.
// Every scan operator in the plan is extracted into its own TableReadCostMetrics without any grouping.
func extractScanAndMemoryFromBinaryPlan(binaryPlan string) ([]*TableReadCostMetrics, error) {
// Step1: convert binaryPlan to tipb.ExplainData
protoBytes, err := plancodec.Decompress(binaryPlan)
if err != nil {
return nil, err
}
explainData := &tipb.ExplainData{}
err = explainData.Unmarshal(protoBytes)
if err != nil {
return nil, err
}
if explainData.DiscardedDueToTooLong {
return nil, fmt.Errorf("plan is too long and was discarded, cannot convert to ExplainData")
}
if !explainData.WithRuntimeStats {
return nil, fmt.Errorf("plan carries no runtime stats, cannot convert to ExplainData")
}
// Step2: abstract the scan time and memory usage from explainData
// the length must start at 0: a pre-allocated nil element would later be dereferenced by the accumulator
operatorExtractMetrics := make([]*TableReadCostMetrics, 0)
// extract scan and memory from main part plan
operatorExtractMetrics, err = extractMetricsFromOperatorTree(explainData.Main, operatorExtractMetrics)
if err != nil {
return nil, err
}
// extract scan and memory from CTES part plan
for _, cte := range explainData.Ctes {
operatorExtractMetrics, err = extractMetricsFromOperatorTree(cte, operatorExtractMetrics)
if err != nil {
return nil, err
}
}
// extract scan and memory from SubQueries part plan
for _, subQ := range explainData.Subqueries {
operatorExtractMetrics, err = extractMetricsFromOperatorTree(subQ, operatorExtractMetrics)
if err != nil {
return nil, err
}
}
return operatorExtractMetrics, nil
}
// AccumulateMetricsGroupByTableID previousMetrics += currentRecordMetrics * frequency
// Step1: find the table id for all currentRecordMetrics
// Step2: Multiply the scan time and memory usage by frequency for all currentRecordMetrics
// Step3: Group by table id and sum with previous metrics
func AccumulateMetricsGroupByTableID(ctx context.Context, currentRecordMetrics []*TableReadCostMetrics, frequency int64, previousMetrics map[int64]*TableReadCostMetrics, infoSchema infoschema.InfoSchema) {
// Multiply exec count for each operator cpu and memory metrics
// Group by table id and sum the scan time and memory usage for each operatorExtractMetrics
// TODO: the operator records the alias table name instead of the real table name, so the alias case cannot be handled yet.
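// For example (hypothetical numbers): an operator that scanned table t for 10ms using 1KiB of
// memory, from a statement executed 5 times, contributes 50ms of scan time, 5KiB of memory
// usage, and a read frequency of 5 to t's accumulated entry.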
for _, operatorExtractMetric := range currentRecordMetrics {
// Step1: find the table id for all currentRecordMetrics
tbl, err := infoSchema.TableByName(ctx, operatorExtractMetric.DbName, operatorExtractMetric.TableName)
if err != nil {
logutil.BgLogger().Warn("failed to find the table id by table name and db name",
zap.String("db_name", operatorExtractMetric.DbName.String()),
zap.String("table_name", operatorExtractMetric.TableName.String()),
zap.Error(err))
continue
}
tableID := tbl.Meta().ID
// Step2: Multiply the scan time and memory usage by frequency
operatorExtractMetric.TableScanTime = time.Duration(operatorExtractMetric.TableScanTime.Nanoseconds() * frequency)
operatorExtractMetric.TableMemUsage *= frequency
operatorExtractMetric.ReadFrequency = frequency
// Step3: Group by table id and sum with previous metrics
if previousMetrics[tableID] == nil {
previousMetrics[tableID] = operatorExtractMetric
} else {
previousMetrics[tableID].TableScanTime += operatorExtractMetric.TableScanTime
previousMetrics[tableID].TableMemUsage += operatorExtractMetric.TableMemUsage
previousMetrics[tableID].ReadFrequency += operatorExtractMetric.ReadFrequency
}
}
}
func extractMetricsFromOperatorTree(op *tipb.ExplainOperator, operatorMetrics []*TableReadCostMetrics) ([]*TableReadCostMetrics, error) {
// Step1: extract operator metrics
// Step1.1: get the operator type from op.name
operatorType, err := extractOperatorTypeFromName(op.Name)
if err != nil {
return operatorMetrics, err
}
switch operatorType {
case plancodec.TypeIndexLookUp, plancodec.TypeIndexReader:
// Handle the case like:
// └─IndexLookUp_x | IndexReader_x
// └─xxx
// └─IndexRangeScan_x | IndexFullScan_x
// Get the time and memory in this layer and find the table name from child IndexXXXScan layer
memUsage := op.MemoryBytes
scanTime, err := extractScanTimeFromExecutionInfo(op)
if err != nil {
return operatorMetrics, fmt.Errorf("failed to extract scan time from execution info, operator name: %s", op.Name)
}
dbName, tableName := extractTableNameFromIndexScan(op)
if dbName == "" || tableName == "" {
return operatorMetrics, fmt.Errorf("failed to get table name from children index xxx operator, operator name: %s", op.Name)
}
operatorMetrics = append(operatorMetrics,
&TableReadCostMetrics{
DbName: ast.NewCIStr(dbName),
TableName: ast.NewCIStr(tableName),
TableScanTime: scanTime,
TableMemUsage: memUsage,
})
case plancodec.TypePointGet, plancodec.TypeBatchPointGet:
if len(op.AccessObjects) == 0 {
return operatorMetrics, fmt.Errorf("failed to get table name while access object is empty, operator name: %s", op.Name)
}
// Attention: cannot handle multiple access objects from one operator
dbName, tableName := extractTableNameFromAccessObject(op.AccessObjects[0])
if dbName == "" || tableName == "" {
return operatorMetrics, fmt.Errorf("failed to get table name from access object, operator name: %s, access object: %s", op.Name, op.AccessObjects[0].String())
}
scanTime, err := extractScanTimeFromExecutionInfo(op)
if err != nil {
return operatorMetrics, fmt.Errorf("failed to extract scan time from execution info, operator name: %s, err: %v", op.Name, err)
}
memUsage := defaultPointGetMemUsage
operatorMetrics = append(operatorMetrics,
&TableReadCostMetrics{
DbName: ast.NewCIStr(dbName),
TableName: ast.NewCIStr(tableName),
TableScanTime: scanTime,
TableMemUsage: int64(memUsage),
})
case plancodec.TypeTableReader:
// Ignore TiFlash query
isTiFlashOp := checkTiFlashOperator(op)
if isTiFlashOp {
return operatorMetrics, nil
}
// Handle the case like:
// └─TableReader_x
// └─xxx
// └─TableFullScan_x | TableRangeScan_x
// Get the time and memory in this layer and find the table name from child TableXXXScan layer
scanTime, err := extractScanTimeFromExecutionInfo(op)
if err != nil {
return operatorMetrics, fmt.Errorf("failed to extract scan time from execution info, operator name: %s", op.Name)
}
memUsage := op.MemoryBytes
dbName, tableName := extractTableNameFromChildrenTableScan(op)
if dbName == "" || tableName == "" {
return operatorMetrics, fmt.Errorf("failed to get table name from children table xxx operator, operator name: %s", op.Name)
}
operatorMetrics = append(operatorMetrics,
&TableReadCostMetrics{
DbName: ast.NewCIStr(dbName),
TableName: ast.NewCIStr(tableName),
TableScanTime: scanTime,
TableMemUsage: memUsage,
})
case plancodec.TypeIndexMerge:
// Handle the case like:
// └─IndexMerge_x
// └─xxx
// └─IndexxxxScan
// └─IndexxxxScan
// Get the memory in this layer and spread it evenly over every child IndexXXXScan
// Get the time and table name from every child IndexXXXScan layer
totalMemoryUsage := op.MemoryBytes
partialMetrics, err := extractPartialMetricsFromChildrenIndexMerge(op)
if err != nil || len(partialMetrics) == 0 {
return operatorMetrics, fmt.Errorf("failed to get time and table name from children operator of index merge, operator name: %s", op.Name)
}
memUsage := totalMemoryUsage / int64(len(partialMetrics))
for _, partialMetric := range partialMetrics {
partialMetric.TableMemUsage = memUsage
operatorMetrics = append(operatorMetrics, partialMetric)
}
default:
// TODO: extend to more operator types: IndexJoin, IndexMergeJoin, TiFlash operators
}
// Step2: Recursively extract the children layer
if len(op.Children) == 0 {
return operatorMetrics, nil
}
for _, child := range op.Children {
operatorMetrics, err = extractMetricsFromOperatorTree(child, operatorMetrics)
if err != nil {
return operatorMetrics, err
}
}
return operatorMetrics, nil
}
// checkTiFlashOperator reports whether any descendant of op runs as an MPP (TiFlash) task.
func checkTiFlashOperator(op *tipb.ExplainOperator) bool {
for _, child := range op.Children {
if child.TaskType == tipb.TaskType_mpp {
return true
}
if checkTiFlashOperator(child) { // recursively check the grandchildren
return true
}
}
return false
}
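// extractTableNameFromAccessObject returns the (db, table) pair recorded in an access object,
// handling both plain scan objects and dynamic-partition objects.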
func extractTableNameFromAccessObject(accessObject *tipb.AccessObject) (dbName string, tableName string) {
switch ao := accessObject.AccessObject.(type) {
case *tipb.AccessObject_DynamicPartitionObjects:
if ao == nil || ao.DynamicPartitionObjects == nil {
return "", ""
}
aos := ao.DynamicPartitionObjects.Objects
if len(aos) == 0 {
return "", ""
}
// If it involves multiple partition objects, only the first non-nil one is returned for now.
for _, access := range aos {
if access == nil {
continue
}
return access.Database, access.Table
}
case *tipb.AccessObject_ScanObject:
if ao == nil || ao.ScanObject == nil {
return "", ""
}
return ao.ScanObject.Database, ao.ScanObject.Table
}
return "", ""
}
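// extractTableNameFromChildrenTableScan searches the subtree under a TableReader for the first
// TableFullScan/TableRangeScan child and returns the (db, table) pair of its access object.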
func extractTableNameFromChildrenTableScan(op *tipb.ExplainOperator) (dbName string, tableName string) {
if len(op.Children) == 0 {
return "", ""
}
for _, child := range op.Children {
childOT, err := extractOperatorTypeFromName(child.Name)
if err != nil {
return "", ""
}
if childOT == plancodec.TypeTableFullScan || childOT == plancodec.TypeTableRangeScan {
if len(child.AccessObjects) == 0 {
return "", ""
}
return extractTableNameFromAccessObject(child.AccessObjects[0])
}
// Recursive search for the children layer
dbName, tableName = extractTableNameFromChildrenTableScan(child)
if dbName != "" && tableName != "" {
return dbName, tableName
}
}
return "", ""
}
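// extractTableNameFromIndexScan searches the subtree under an IndexLookUp/IndexReader for the
// first IndexRangeScan/IndexFullScan child and returns the (db, table) pair of its access object.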
func extractTableNameFromIndexScan(op *tipb.ExplainOperator) (dbName string, tableName string) {
if len(op.Children) == 0 {
return "", ""
}
for _, child := range op.Children {
childOT, err := extractOperatorTypeFromName(child.Name)
if err != nil {
return "", ""
}
if childOT == plancodec.TypeIndexRangeScan || childOT == plancodec.TypeIndexFullScan {
if len(child.AccessObjects) == 0 {
return "", ""
}
return extractTableNameFromAccessObject(child.AccessObjects[0])
}
// Recursive search for the children layer
dbName, tableName = extractTableNameFromIndexScan(child)
if dbName != "" && tableName != "" {
return dbName, tableName
}
}
return "", ""
}
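// extractPartialMetricsFromChildrenIndexMerge descends to the layer whose children are
// IndexRangeScan/IndexFullScan operators and returns one partial TableReadCostMetrics
// (table name and scan time) per child; the caller is expected to fill in the memory usage.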
func extractPartialMetricsFromChildrenIndexMerge(op *tipb.ExplainOperator) ([]*TableReadCostMetrics, error) {
if len(op.Children) == 0 {
return nil, nil
}
// Find the layer of IndexxxxScan
children0OperatorType, err := extractOperatorTypeFromName(op.Children[0].Name)
if err != nil {
return nil, err
}
if children0OperatorType == plancodec.TypeIndexRangeScan || children0OperatorType == plancodec.TypeIndexFullScan {
result := make([]*TableReadCostMetrics, 0, len(op.Children))
for _, child := range op.Children {
if len(child.AccessObjects) == 0 {
return nil, fmt.Errorf("failed to get table name while access object is empty, operator name: %s", child.Name)
}
dbName, tableName := extractTableNameFromAccessObject(child.AccessObjects[0])
if dbName == "" || tableName == "" {
return nil, fmt.Errorf("failed to get table name from children index xxx operator, operator name: %s", child.Name)
}
scanTime, err := extractScanTimeFromExecutionInfo(child)
if err != nil {
return nil, fmt.Errorf("failed to extract scan time from execution info, operator name: %s", child.Name)
}
partialMetric := &TableReadCostMetrics{
DbName: ast.NewCIStr(dbName),
TableName: ast.NewCIStr(tableName),
TableScanTime: scanTime,
}
result = append(result, partialMetric)
}
return result, nil
}
// Recursively find the children layer
var result []*TableReadCostMetrics
for _, child := range op.Children {
result, err = extractPartialMetricsFromChildrenIndexMerge(child)
if len(result) != 0 {
return result, nil
}
}
return nil, err
}
// extractOperatorTypeFromName extracts the operator type from the operator name.
// The operator name pattern seems to be like "OperatorType_ID"
// For example, "TableReader_1", "IndexLookUp_2", "PointGet_3", etc.
// The function will return the operator type part, such as "TableReader", "IndexLookUp", "PointGet", etc.
func extractOperatorTypeFromName(name string) (string, error) {
names := strings.Split(name, "_")
if len(names) != 2 {
return "", fmt.Errorf("failed to extract operator type from operator name: %s", name)
}
return names[0], nil
}
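// extractScanTimeFromExecutionInfo extracts the operator execution time from its runtime
// information, trying RootBasicExecInfo first, then the first RootGroupExecInfo entry, and
// finally CopExecInfo.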
func extractScanTimeFromExecutionInfo(op *tipb.ExplainOperator) (time.Duration, error) {
var scanTime time.Duration
var err error
if len(op.RootBasicExecInfo) > 0 {
scanTime, err = extractScanTimeFromString(op.RootBasicExecInfo)
}
if scanTime == 0 && len(op.RootGroupExecInfo) > 0 {
// try groupExecInfo instead
scanTime, err = extractScanTimeFromString(op.RootGroupExecInfo[0])
}
if scanTime == 0 && len(op.CopExecInfo) > 0 {
// try copExecInfo instead
scanTime, err = extractScanTimeFromString(op.CopExecInfo)
}
return scanTime, err
}
// extractScanTimeFromString extracts the scan time from execution info string that contains "time: <duration>,"
// The execution info string pattern seems to be like:
// "time:274.5µs, loops:1, cop_task:......"
// We need the duration after "time:" and before the first comma.
func extractScanTimeFromString(input string) (time.Duration, error) {
start := strings.Index(input, "time:")
if start == -1 {
return 0, fmt.Errorf("'time:' not found in input string")
}
start += len("time:") // move start index after "time:"
end := strings.Index(input[start:], ",")
if end == -1 {
return 0, fmt.Errorf("',' not found after 'time:'")
}
end += start // compute the end index
// extract timeStr excluding "time:" and "," and empty str
timeStr := strings.TrimSpace(input[start:end])
// parse str to Duration
duration, err := time.ParseDuration(timeStr)
if err != nil {
return 0, fmt.Errorf("failed to parse duration: %v", err)
}
return duration, nil
}
// DBNameExtractor is an AST visitor that extracts all database names from query.
// It is used to check whether the query is a cross-database query or not.
type DBNameExtractor struct {
// DBs is a map to store the database names found in the AST nodes.
DBs map[string]struct{}
}
// Enter is called when entering a node in the AST.
// It extracts the database name from each TableName AST node and saves it in the DBs map.
func (e *DBNameExtractor) Enter(n ast.Node) (ast.Node, bool) {
if table, ok := n.(*ast.TableName); ok {
if e.DBs == nil {
e.DBs = make(map[string]struct{})
}
e.DBs[table.Schema.L] = struct{}{}
}
return n, false
}
// Leave is called when leaving a node in the AST.
func (*DBNameExtractor) Leave(n ast.Node) (ast.Node, bool) {
return n, true
}
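// A minimal usage sketch (assuming stmt is an ast.Node produced by the parser; the variable
// names are illustrative):
//
//	extractor := &DBNameExtractor{}
//	stmt.Accept(extractor)
//	isCrossDB := len(extractor.DBs) > 1
//
// After Accept returns, extractor.DBs holds every schema name referenced by a TableName node
// in the statement.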