Files
tidb/disttask/framework/storage/task_table.go

358 lines
10 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)
// GlobalTaskManager is the manager of global task.
// It reads and writes rows of mysql.tidb_global_task through a single
// session; every method serializes its use of that session via mu.
type GlobalTaskManager struct {
	// ctx is tagged with the kv.InternalDistTask source type by NewGlobalTaskManager.
	ctx context.Context
	// se is the session used to execute the internal SQL statements.
	se sessionctx.Context
	// mu serializes use of se across concurrent callers.
	mu sync.Mutex
}
// Process-wide manager singletons. They are installed with SetGlobalTaskManager
// and SetSubTaskManager and read back with the matching Get* functions; the
// atomic pointers make those loads and stores safe without extra locking.
var (
	globalTaskManagerInstance atomic.Pointer[GlobalTaskManager]
	subTaskManagerInstance    atomic.Pointer[SubTaskManager]
)
// NewGlobalTaskManager returns a GlobalTaskManager that issues its SQL on se.
// The supplied context is wrapped with the kv.InternalDistTask internal source
// type before being stored, so every statement the manager runs carries it.
func NewGlobalTaskManager(ctx context.Context, se sessionctx.Context) *GlobalTaskManager {
	mgr := &GlobalTaskManager{se: se}
	mgr.ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask)
	return mgr
}
// NewSubTaskManager returns a SubTaskManager that issues its SQL on se.
// The supplied context is wrapped with the kv.InternalDistTask internal source
// type before being stored, so every statement the manager runs carries it.
func NewSubTaskManager(ctx context.Context, se sessionctx.Context) *SubTaskManager {
	mgr := &SubTaskManager{se: se}
	mgr.ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask)
	return mgr
}
// GetGlobalTaskManager returns the process-wide GlobalTaskManager. It fails
// when SetGlobalTaskManager has not stored a non-nil manager yet.
func GetGlobalTaskManager() (*GlobalTaskManager, error) {
	if mgr := globalTaskManagerInstance.Load(); mgr != nil {
		return mgr, nil
	}
	return nil, errors.New("global task manager is not initialized")
}
// SetGlobalTaskManager installs is as the process-wide GlobalTaskManager,
// replacing whatever was stored before. Storing nil makes
// GetGlobalTaskManager report the manager as uninitialized again.
func SetGlobalTaskManager(is *GlobalTaskManager) {
	globalTaskManagerInstance.Store(is)
}
// GetSubTaskManager returns the process-wide SubTaskManager. It fails when
// SetSubTaskManager has not stored a non-nil manager yet.
func GetSubTaskManager() (*SubTaskManager, error) {
	if mgr := subTaskManagerInstance.Load(); mgr != nil {
		return mgr, nil
	}
	return nil, errors.New("subTask manager is not initialized")
}
// SetSubTaskManager installs is as the process-wide SubTaskManager,
// replacing whatever was stored before. Storing nil makes GetSubTaskManager
// report the manager as uninitialized again.
func SetSubTaskManager(is *SubTaskManager) {
	subTaskManagerInstance.Store(is)
}
// execSQL executes sql (with %? placeholders bound to args) on se and returns
// every row of the result set. It returns nil rows when the statement yields
// no record set (presumably non-query statements such as insert/update/delete).
//
// Whenever a record set is produced it is closed before returning — including
// when draining it fails, which previously leaked it. A Close error is only
// reported when nothing else failed first; on any error the rows are nil.
// TODO: consider retry.
func execSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) (rows []chunk.Row, err error) {
	rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	if rs == nil {
		return nil, nil
	}
	defer func() {
		// Surface the Close error only if draining succeeded, so the
		// original failure is never masked.
		if closeErr := rs.Close(); closeErr != nil && err == nil {
			rows, err = nil, closeErr
		}
	}()
	rows, err = sqlexec.DrainRecordSet(ctx, rs, 1)
	if err != nil {
		return nil, err
	}
	return rows, nil
}
// row2GlobeTask converts a row of `select * from mysql.tidb_global_task` into
// a proto.Task. Columns are read by position: 0..3 fill ID, Type,
// DispatcherID and State, 4 and 5 the start and state-update times, 6 Meta,
// 7 Concurrency and 8 Step. NOTE(review): these positions must match the
// table's column order — confirm against the table DDL when it changes.
func row2GlobeTask(r chunk.Row) *proto.Task {
	// TODO: convert to local time.
	// Conversion errors are discarded and the converted value is kept as-is.
	// NOTE(review): AddNewTask never writes start_time, so it may be NULL/zero
	// here and fail conversion routinely — confirm before surfacing this error.
	startTime, _ := r.GetTime(4).GoTime(time.UTC)
	stateUpdateTime, _ := r.GetTime(5).GoTime(time.UTC)

	return &proto.Task{
		ID:              r.GetInt64(0),
		Type:            r.GetString(1),
		DispatcherID:    r.GetString(2),
		State:           r.GetString(3),
		StartTime:       startTime,
		StateUpdateTime: stateUpdateTime,
		Meta:            r.GetBytes(6),
		Concurrency:     uint64(r.GetInt64(7)),
		Step:            r.GetInt64(8),
	}
}
// AddNewTask inserts a pending task of type tp with the given concurrency and
// meta into mysql.tidb_global_task and returns the ID assigned to it.
//
// The insert and the following `select @@last_insert_id` run on the same
// session while mu is held, so no other call on this manager can run between
// them. An empty result for the ID query is reported as an error instead of
// panicking on rs[0].
func (stm *GlobalTaskManager) AddNewTask(tp string, concurrency int, meta []byte) (int64, error) {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	_, err := execSQL(stm.ctx, stm.se, "insert into mysql.tidb_global_task(type, state, concurrency, meta, state_update_time) values (%?, %?, %?, %?, %?)", tp, proto.TaskStatePending, concurrency, meta, time.Now().UTC().String())
	if err != nil {
		return 0, err
	}

	rs, err := execSQL(stm.ctx, stm.se, "select @@last_insert_id")
	if err != nil {
		return 0, err
	}
	if len(rs) == 0 {
		return 0, errors.New("no result for select @@last_insert_id")
	}
	// The system variable is read back as a string, hence the parse.
	return strconv.ParseInt(rs[0].GetString(0), 10, 64)
}
// GetNewTask returns one task that is still in the pending state, or nil (with
// a nil error) when there is none. It is meant to be called by the dispatcher only.
func (stm *GlobalTaskManager) GetNewTask() (task *proto.Task, err error) {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	rows, err := execSQL(stm.ctx, stm.se, "select * from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending)
	switch {
	case err != nil:
		return nil, err
	case len(rows) == 0:
		return nil, nil
	default:
		return row2GlobeTask(rows[0]), nil
	}
}
// UpdateTask persists task's state, dispatcher ID, step, state-update time
// (written in UTC) and concurrency to the row whose id equals task.ID.
func (stm *GlobalTaskManager) UpdateTask(task *proto.Task) error {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	const updateSQL = "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %? where id = %?"
	_, err := execSQL(stm.ctx, stm.se, updateSQL,
		task.State, task.DispatcherID, task.Step, task.StateUpdateTime.UTC().String(), task.Concurrency, task.ID)
	return err
}
// GetTasksInStates returns every global task whose state is one of states.
// With no states, or when nothing matches, it returns a nil slice.
func (stm *GlobalTaskManager) GetTasksInStates(states ...interface{}) (task []*proto.Task, err error) {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	if len(states) == 0 {
		return nil, nil
	}
	// One "%?" placeholder per state, comma separated.
	placeholders := strings.TrimSuffix(strings.Repeat("%?,", len(states)), ",")
	rows, err := execSQL(stm.ctx, stm.se, "select * from mysql.tidb_global_task where state in ("+placeholders+")", states...)
	if err != nil {
		return nil, err
	}
	for _, row := range rows {
		task = append(task, row2GlobeTask(row))
	}
	return task, nil
}
// GetTaskByID returns the global task whose id is taskID, or nil (with a nil
// error) when no such row exists.
func (stm *GlobalTaskManager) GetTaskByID(taskID int64) (task *proto.Task, err error) {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	rows, err := execSQL(stm.ctx, stm.se, "select * from mysql.tidb_global_task where id = %?", taskID)
	if err != nil {
		return nil, err
	}
	if len(rows) > 0 {
		task = row2GlobeTask(rows[0])
	}
	return task, nil
}
// SubTaskManager is the manager of subtask.
// It reads and writes rows of mysql.tidb_background_subtask through a single
// session; every method serializes its use of that session via mu.
type SubTaskManager struct {
	// ctx is tagged with the kv.InternalDistTask source type by NewSubTaskManager.
	ctx context.Context
	// se is the session used to execute the internal SQL statements.
	se sessionctx.Context
	// mu serializes use of se across concurrent callers.
	mu sync.Mutex
}
// row2SubTask converts a row of `select * from mysql.tidb_background_subtask`
// into a proto.Subtask, reading columns by position (0 ID, 2 task key,
// 4 type, 5 scheduler/exec ID, 7 state, 9 start time, 11 meta).
// NOTE(review): these positions must match the table's column order — confirm
// against the table DDL when it changes.
//
// The task key column is read as a string and parsed as a 64-bit integer. If
// it cannot be parsed, a warning is logged and TaskID holds the value
// strconv.ParseInt returned alongside the error (0 for malformed input).
func row2SubTask(r chunk.Row) *proto.Subtask {
	task := &proto.Subtask{
		ID:          r.GetInt64(0),
		Type:        proto.Int2Type(int(r.GetInt64(4))),
		SchedulerID: r.GetString(5),
		State:       r.GetString(7),
		Meta:        r.GetBytes(11),
		StartTime:   r.GetUint64(9),
	}
	taskKey := r.GetString(2)
	// ParseInt with bitSize 64 matches the int64 TaskID field. strconv.Atoi
	// would parse into a platform-sized int and truncate on 32-bit builds.
	tid, err := strconv.ParseInt(taskKey, 10, 64)
	if err != nil {
		logutil.BgLogger().Warn("unexpected task ID", zap.String("task ID", taskKey), zap.Error(err))
	}
	task.TaskID = tid
	return task
}
// AddNewTask inserts one subtask of global task globalTaskID, assigned to the
// TiDB instance designatedTiDBID, with an empty checkpoint. The subtask starts
// in the revert-pending state when isRevert is set, otherwise in pending.
func (stm *SubTaskManager) AddNewTask(globalTaskID int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error {
	state := proto.TaskStatePending
	if isRevert {
		state = proto.TaskStateRevertPending
	}

	stm.mu.Lock()
	defer stm.mu.Unlock()

	_, err := execSQL(stm.ctx, stm.se,
		"insert into mysql.tidb_background_subtask(task_key, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?)",
		globalTaskID, designatedTiDBID, meta, state, proto.Type2Int(tp), []byte{})
	return err
}
// GetSubtaskInStates returns one subtask of global task taskID that is
// assigned to tidbID and whose state is one of states, or nil (with a nil
// error) when none matches.
//
// An empty states list cannot match anything, so it returns nil without
// querying. Previously strings.Repeat was called with a count of -1 in that
// case, which panics.
func (stm *SubTaskManager) GetSubtaskInStates(tidbID string, taskID int64, states ...interface{}) (*proto.Subtask, error) {
	if len(states) == 0 {
		return nil, nil
	}

	stm.mu.Lock()
	defer stm.mu.Unlock()

	args := make([]interface{}, 0, len(states)+2)
	args = append(args, tidbID, taskID)
	args = append(args, states...)
	rs, err := execSQL(stm.ctx, stm.se, "select * from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
	if err != nil {
		return nil, err
	}
	if len(rs) == 0 {
		return nil, nil
	}
	return row2SubTask(rs[0]), nil
}
// GetSubtaskInStatesCnt returns how many subtasks of global task taskID are in
// one of states.
//
// An empty states list matches nothing, so it returns 0 without querying.
// Previously strings.Repeat was called with a count of -1 in that case, which
// panics. An empty result from the count query is reported as an error instead
// of panicking on rs[0].
func (stm *SubTaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) {
	if len(states) == 0 {
		return 0, nil
	}

	stm.mu.Lock()
	defer stm.mu.Unlock()

	args := make([]interface{}, 0, len(states)+1)
	args = append(args, taskID)
	args = append(args, states...)
	rs, err := execSQL(stm.ctx, stm.se, "select count(*) from mysql.tidb_background_subtask where task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
	if err != nil {
		return 0, err
	}
	if len(rs) == 0 {
		return 0, errors.New("no result for subtask count query")
	}
	return rs[0].GetInt64(0), nil
}
// HasSubtasksInStates reports whether at least one subtask of global task
// taskID that is assigned to tidbID is in one of states.
//
// An empty states list matches nothing, so it returns false without querying.
// Previously strings.Repeat was called with a count of -1 in that case, which
// panics.
func (stm *SubTaskManager) HasSubtasksInStates(tidbID string, taskID int64, states ...interface{}) (bool, error) {
	if len(states) == 0 {
		return false, nil
	}

	stm.mu.Lock()
	defer stm.mu.Unlock()

	args := make([]interface{}, 0, len(states)+2)
	args = append(args, tidbID, taskID)
	args = append(args, states...)
	rs, err := execSQL(stm.ctx, stm.se, "select 1 from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...)
	if err != nil {
		return false, err
	}
	return len(rs) > 0, nil
}
// UpdateSubtaskState sets the state column of the subtask with the given id.
func (stm *SubTaskManager) UpdateSubtaskState(id int64, state string) error {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	if _, err := execSQL(stm.ctx, stm.se, "update mysql.tidb_background_subtask set state = %? where id = %?", state, id); err != nil {
		return err
	}
	return nil
}
// UpdateHeartbeat writes heartbeat into the exec_expired column of every
// subtask of global task taskID that is assigned to instanceID.
func (stm *SubTaskManager) UpdateHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	// time.Time.String appends a final " m=±<seconds>" field when the value
	// still carries a monotonic clock reading (e.g. straight from time.Now()).
	// That field is not part of a datetime literal. Round(0) strips only the
	// monotonic reading; wall clock and location are unchanged.
	_, err := execSQL(stm.ctx, stm.se, "update mysql.tidb_background_subtask set exec_expired = %? where exec_id = %? and task_key = %?", heartbeat.Round(0).String(), instanceID, taskID)
	return err
}
// DeleteTasks removes every subtask row that belongs to global task taskID.
func (stm *SubTaskManager) DeleteTasks(taskID int64) error {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	const deleteSQL = "delete from mysql.tidb_background_subtask where task_key = %?"
	_, err := execSQL(stm.ctx, stm.se, deleteSQL, taskID)
	return err
}
// GetSchedulerIDs returns the distinct exec_id values of the subtasks of
// global task taskID, i.e. the instances they are assigned to. It returns a
// nil slice when the task has no subtasks.
func (stm *SubTaskManager) GetSchedulerIDs(taskID int64) ([]string, error) {
	stm.mu.Lock()
	defer stm.mu.Unlock()

	rows, err := execSQL(stm.ctx, stm.se, "select distinct(exec_id) from mysql.tidb_background_subtask where task_key = %?", taskID)
	if err != nil {
		return nil, err
	}
	if len(rows) == 0 {
		return nil, nil
	}
	ids := make([]string, len(rows))
	for i := range rows {
		ids[i] = rows[i].GetString(0)
	}
	return ids, nil
}