Files
tidb/pkg/workloadlearning/cache.go

155 lines
5.4 KiB
Go

// Copyright 2025 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package workloadlearning
import (
"context"
"encoding/json"
"sync"
"github.com/pingcap/tidb/pkg/kv"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)
// TableReadCostCache stores the cached workload learning metrics.
// It pairs the per-table read cost metrics with the version they belong to.
type TableReadCostCache struct {
// TableReadCostMetrics maps a table ID to the read cost metrics cached for that table.
TableReadCostMetrics map[int64]*TableReadCostMetrics // key: tableID
// Version is the version of the metrics currently held in TableReadCostMetrics.
// It starts at 0 and is set to the storage version each time the cache is refreshed.
Version uint64
}
// WLCacheWorker is the worker that caches all workload-related metrics.
// For now it also holds the cached data of the table cost metrics.
type WLCacheWorker struct {
// sysSessionPool supplies the sessions used to query storage when refreshing the cache.
sysSessionPool util.DestroyableSessionPool
// tableReadCostCache holds the table read cost metrics and their version.
tableReadCostCache *TableReadCostCache
// The embedded RWMutex is taken for writing by updateTableReadCostCacheWithMetrics
// and for reading by GetTableReadCostMetrics.
sync.RWMutex
}
// NewWLCacheWorker builds a workload learning cache worker that caches
// workload-related metrics from storage (mysql.tidb_workload_values) in memory.
//
// The returned worker uses pool to obtain sessions and starts with an empty,
// non-nil metrics map at version 0.
func NewWLCacheWorker(pool util.DestroyableSessionPool) *WLCacheWorker {
	worker := &WLCacheWorker{sysSessionPool: pool}
	// Version and the embedded RWMutex are left at their zero values.
	worker.tableReadCostCache = &TableReadCostCache{
		TableReadCostMetrics: map[int64]*TableReadCostMetrics{},
	}
	return worker
}
// UpdateTableReadCostCache refreshes the cached table read cost metrics from
// mysql.tidb_workload_values.
//
// It queries the newest version stored for (feedbackCategory, tableReadCost)
// and, only when that version is greater than the cached one, loads every row
// of that version and installs the result in the cache. Rows whose value
// cannot be unmarshalled are logged and skipped. Every failure is logged and
// leaves the cache untouched; nothing is reported back to the caller.
func (cw *WLCacheWorker) UpdateTableReadCostCache() {
	// Get a session to read the latest metrics from storage.
	se, err := cw.sysSessionPool.Get()
	if err != nil {
		logutil.BgLogger().Warn("Get system session failed when updating table cost cache", zap.Error(err))
		return
	}
	// The deferred cleanup inspects the function-level err, so every SQL call
	// below must assign to this same variable (never shadow it).
	defer func() {
		if err != nil {
			// Note: Otherwise, the session will be leaked.
			cw.sysSessionPool.Destroy(se)
			return
		}
		// Only recycle the session when no error occurred.
		cw.sysSessionPool.Put(se)
	}()
	sctx := se.(sessionctx.Context)
	exec := sctx.GetRestrictedSQLExecutor()
	ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnWorkloadLearning)

	// Decide whether the table cost metrics need updating by fetching the
	// latest version present in storage.
	// TODO(elsa): Add the index of (category, type, version) to mysql.tidb_workload_values
	sql := `SELECT version FROM mysql.tidb_workload_values
WHERE category = %? AND type = %?
ORDER BY version DESC LIMIT 1`
	rows, _, err := exec.ExecRestrictedSQL(ctx, nil, sql, feedbackCategory, tableReadCost)
	if err != nil {
		logutil.ErrVerboseLogger().Warn("Failed to get the latest table cost version", zap.Error(err))
		return
	}
	// Case: no metrics belong to this feedback category and type.
	if len(rows) != 1 {
		logutil.BgLogger().Warn("The result of latest table cost version query is not 1",
			zap.Int("result_rows", len(rows)))
		return
	}

	latestVersionInStorage := rows[0].GetUint64(0)
	// Read the cached version under the read lock: the writer side
	// (updateTableReadCostCacheWithMetrics) changes it under the write lock,
	// so an unguarded read here would be a data race if refreshes overlap.
	cw.RWMutex.RLock()
	cachedVersion := cw.tableReadCostCache.Version
	cw.RWMutex.RUnlock()
	// Nothing to do unless storage holds a newer version than the cache.
	if latestVersionInStorage <= cachedVersion {
		logutil.BgLogger().Info("The latest table cost version in storage is the same as the cached version, no need to update",
			zap.Uint64("latest_version_in_storage", latestVersionInStorage),
			zap.Uint64("cached_version", cachedVersion))
		return
	}

	// Load the table cost metrics recorded for the latest version.
	sql = `SELECT table_id, value FROM mysql.tidb_workload_values
WHERE category = %? AND type = %? AND version = %?`
	rows, _, err = exec.ExecRestrictedSQL(ctx, nil, sql, feedbackCategory, tableReadCost, latestVersionInStorage)
	if err != nil {
		logutil.ErrVerboseLogger().Warn("Failed to get the latest table cost metrics",
			zap.Error(err))
		return
	}

	newMetrics := make(map[int64]*TableReadCostMetrics, len(rows))
	for _, row := range rows {
		tableID := row.GetInt64(0)
		value := row.GetJSON(1).String()
		metric := &TableReadCostMetrics{}
		// Use a separate error variable: a malformed row is skipped and must
		// not make the deferred cleanup destroy the session.
		if unmarshalErr := json.Unmarshal([]byte(value), metric); unmarshalErr != nil {
			logutil.ErrVerboseLogger().Warn("Failed to unmarshal table cost metrics",
				zap.Int64("table_id", tableID),
				zap.String("value", value),
				zap.Error(unmarshalErr))
			continue
		}
		newMetrics[tableID] = metric
	}

	// Install the new metrics and version together under the write lock.
	cw.updateTableReadCostCacheWithMetrics(newMetrics, latestVersionInStorage)
}
// updateTableReadCostCacheWithMetrics installs newMetrics as the cached
// per-table metrics and records latestVersionInStorage as the cached version.
// Both fields are written while holding the write lock.
func (cw *WLCacheWorker) updateTableReadCostCacheWithMetrics(newMetrics map[int64]*TableReadCostMetrics,
	latestVersionInStorage uint64) {
	cw.Lock()
	defer cw.Unlock()

	// The existing cache struct is mutated in place rather than replaced.
	cache := cw.tableReadCostCache
	cache.Version = latestVersionInStorage
	cache.TableReadCostMetrics = newMetrics
}
// GetTableReadCostMetrics looks up the cached metrics for tableID while
// holding the read lock. It returns nil when the table has no cached entry;
// otherwise it returns a newly allocated TableReadCostMetrics, never the
// cached pointer itself.
func (cw *WLCacheWorker) GetTableReadCostMetrics(tableID int64) *TableReadCostMetrics {
	cw.RLock()
	defer cw.RUnlock()

	if cached, ok := cw.tableReadCostCache.TableReadCostMetrics[tableID]; ok {
		// Copy the four fields into a fresh struct so callers cannot modify
		// the entry stored in the cache.
		snapshot := new(TableReadCostMetrics)
		snapshot.TableScanTime = cached.TableScanTime
		snapshot.TableMemUsage = cached.TableMemUsage
		snapshot.ReadFrequency = cached.ReadFrequency
		snapshot.TableReadCost = cached.TableReadCost
		return snapshot
	}
	return nil
}