447 lines
14 KiB
Go
447 lines
14 KiB
Go
// Copyright 2023 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package storage
|
|
|
|
import (
|
|
"context"
|
|
"strconv"
|
|
"strings"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/ngaut/pools"
|
|
"github.com/pingcap/errors"
|
|
"github.com/pingcap/failpoint"
|
|
"github.com/pingcap/tidb/disttask/framework/proto"
|
|
"github.com/pingcap/tidb/kv"
|
|
"github.com/pingcap/tidb/parser/terror"
|
|
"github.com/pingcap/tidb/sessionctx"
|
|
"github.com/pingcap/tidb/util/chunk"
|
|
"github.com/pingcap/tidb/util/logutil"
|
|
"github.com/pingcap/tidb/util/sqlexec"
|
|
"github.com/tikv/client-go/v2/util"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// TaskManager is the manager of global/sub task.
|
|
type TaskManager struct {
|
|
ctx context.Context
|
|
sePool *pools.ResourcePool
|
|
}
|
|
|
|
var taskManagerInstance atomic.Pointer[TaskManager]
|
|
|
|
var (
|
|
// TestLastTaskID is used for test to set the last task ID.
|
|
TestLastTaskID atomic.Int64
|
|
)
|
|
|
|
// NewTaskManager creates a new task manager.
|
|
func NewTaskManager(ctx context.Context, sePool *pools.ResourcePool) *TaskManager {
|
|
ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask)
|
|
return &TaskManager{
|
|
ctx: ctx,
|
|
sePool: sePool,
|
|
}
|
|
}
|
|
|
|
// GetTaskManager gets the task manager.
|
|
func GetTaskManager() (*TaskManager, error) {
|
|
v := taskManagerInstance.Load()
|
|
if v == nil {
|
|
return nil, errors.New("global task manager is not initialized")
|
|
}
|
|
return v, nil
|
|
}
|
|
|
|
// SetTaskManager sets the task manager.
|
|
func SetTaskManager(is *TaskManager) {
|
|
taskManagerInstance.Store(is)
|
|
}
|
|
|
|
// execSQL executes the sql and returns the result.
|
|
// TODO: consider retry.
|
|
func execSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) {
|
|
rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if rs != nil {
|
|
defer terror.Call(rs.Close)
|
|
return sqlexec.DrainRecordSet(ctx, rs, 1024)
|
|
}
|
|
return nil, nil
|
|
}
|
|
|
|
// row2GlobeTask converts a row to a global task.
|
|
func row2GlobeTask(r chunk.Row) *proto.Task {
|
|
task := &proto.Task{
|
|
ID: r.GetInt64(0),
|
|
Key: r.GetString(1),
|
|
Type: r.GetString(2),
|
|
DispatcherID: r.GetString(3),
|
|
State: r.GetString(4),
|
|
Meta: r.GetBytes(7),
|
|
Concurrency: uint64(r.GetInt64(8)),
|
|
Step: r.GetInt64(9),
|
|
Error: r.GetBytes(10),
|
|
}
|
|
// TODO: convert to local time.
|
|
task.StartTime, _ = r.GetTime(5).GoTime(time.UTC)
|
|
task.StateUpdateTime, _ = r.GetTime(6).GoTime(time.UTC)
|
|
return task
|
|
}
|
|
|
|
// WithNewSession executes the function with a new session.
|
|
func (stm *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error {
|
|
se, err := stm.sePool.Get()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer stm.sePool.Put(se)
|
|
return fn(se.(sessionctx.Context))
|
|
}
|
|
|
|
func (stm *TaskManager) withNewTxn(fn func(se sessionctx.Context) error) error {
|
|
return stm.WithNewSession(func(se sessionctx.Context) (err error) {
|
|
_, err = execSQL(stm.ctx, se, "begin")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
success := false
|
|
defer func() {
|
|
sql := "rollback"
|
|
if success {
|
|
sql = "commit"
|
|
}
|
|
_, commitErr := execSQL(stm.ctx, se, sql)
|
|
if err == nil && commitErr != nil {
|
|
err = commitErr
|
|
}
|
|
}()
|
|
|
|
if err = fn(se); err != nil {
|
|
return err
|
|
}
|
|
|
|
success = true
|
|
return nil
|
|
})
|
|
}
|
|
|
|
func (stm *TaskManager) executeSQLWithNewSession(ctx context.Context, sql string, args ...interface{}) (rs []chunk.Row, err error) {
|
|
err = stm.WithNewSession(func(se sessionctx.Context) error {
|
|
rs, err = execSQL(ctx, se, sql, args...)
|
|
return err
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// AddNewGlobalTask adds a new task to global task table.
|
|
func (stm *TaskManager) AddNewGlobalTask(key, tp string, concurrency int, meta []byte) (taskID int64, err error) {
|
|
err = stm.WithNewSession(func(se sessionctx.Context) error {
|
|
_, err = execSQL(stm.ctx, se, "insert into mysql.tidb_global_task(task_key, type, state, concurrency, step, meta, state_update_time) values (%?, %?, %?, %?, %?, %?, %?)", key, tp, proto.TaskStatePending, concurrency, proto.StepInit, meta, time.Now().UTC().String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
rs, err := execSQL(stm.ctx, se, "select @@last_insert_id")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
taskID, err = strconv.ParseInt(rs[0].GetString(0), 10, 64)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) })
|
|
|
|
return nil
|
|
})
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return
|
|
}
|
|
|
|
// GetNewGlobalTask get a new task from global task table, it's used by dispatcher only.
|
|
func (stm *TaskManager) GetNewGlobalTask() (task *proto.Task, err error) {
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending)
|
|
if err != nil {
|
|
return task, err
|
|
}
|
|
|
|
if len(rs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return row2GlobeTask(rs[0]), nil
|
|
}
|
|
|
|
// GetGlobalTasksInStates gets the tasks in the states.
|
|
func (stm *TaskManager) GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error) {
|
|
if len(states) == 0 {
|
|
return task, nil
|
|
}
|
|
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...)
|
|
if err != nil {
|
|
return task, err
|
|
}
|
|
|
|
for _, r := range rs {
|
|
task = append(task, row2GlobeTask(r))
|
|
}
|
|
return task, nil
|
|
}
|
|
|
|
// GetGlobalTaskByID gets the task by the global task ID.
|
|
func (stm *TaskManager) GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) {
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where id = %?", taskID)
|
|
if err != nil {
|
|
return task, err
|
|
}
|
|
if len(rs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return row2GlobeTask(rs[0]), nil
|
|
}
|
|
|
|
// GetGlobalTaskByKey gets the task by the task key
|
|
func (stm *TaskManager) GetGlobalTaskByKey(key string) (task *proto.Task, err error) {
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where task_key = %?", key)
|
|
if err != nil {
|
|
return task, err
|
|
}
|
|
if len(rs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return row2GlobeTask(rs[0]), nil
|
|
}
|
|
|
|
// row2SubTask converts a row to a subtask.
|
|
func row2SubTask(r chunk.Row) *proto.Subtask {
|
|
task := &proto.Subtask{
|
|
ID: r.GetInt64(0),
|
|
Step: r.GetInt64(1),
|
|
Type: proto.Int2Type(int(r.GetInt64(5))),
|
|
SchedulerID: r.GetString(6),
|
|
State: r.GetString(8),
|
|
Meta: r.GetBytes(12),
|
|
StartTime: r.GetUint64(10),
|
|
}
|
|
tid, err := strconv.Atoi(r.GetString(3))
|
|
if err != nil {
|
|
logutil.BgLogger().Warn("unexpected task ID", zap.String("task ID", r.GetString(3)))
|
|
}
|
|
task.TaskID = int64(tid)
|
|
return task
|
|
}
|
|
|
|
// AddNewSubTask adds a new task to subtask table.
|
|
func (stm *TaskManager) AddNewSubTask(globalTaskID int64, step int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error {
|
|
st := proto.TaskStatePending
|
|
if isRevert {
|
|
st = proto.TaskStateRevertPending
|
|
}
|
|
|
|
_, err := stm.executeSQLWithNewSession(stm.ctx, "insert into mysql.tidb_background_subtask(task_key, step, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)", globalTaskID, step, designatedTiDBID, meta, st, proto.Type2Int(tp), []byte{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetSubtaskInStates gets the subtask in the states.
|
|
func (stm *TaskManager) GetSubtaskInStates(tidbID string, taskID int64, states ...interface{}) (*proto.Subtask, error) {
|
|
args := []interface{}{tidbID, taskID}
|
|
args = append(args, states...)
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(rs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
return row2SubTask(rs[0]), nil
|
|
}
|
|
|
|
// GetSucceedSubtasksByStep gets the subtask in the success state.
|
|
func (stm *TaskManager) GetSucceedSubtasksByStep(taskID int64, step int64) ([]*proto.Subtask, error) {
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where task_key = %? and state = %? and step = %?", taskID, proto.TaskStateSucceed, step)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(rs) == 0 {
|
|
return nil, nil
|
|
}
|
|
subtasks := make([]*proto.Subtask, 0, len(rs))
|
|
for _, r := range rs {
|
|
subtasks = append(subtasks, row2SubTask(r))
|
|
}
|
|
return subtasks, nil
|
|
}
|
|
|
|
// GetSubtaskInStatesCnt gets the subtask count in the states.
|
|
func (stm *TaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) {
|
|
args := []interface{}{taskID}
|
|
args = append(args, states...)
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select count(*) from mysql.tidb_background_subtask where task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return rs[0].GetInt64(0), nil
|
|
}
|
|
|
|
// CollectSubTaskError collects the subtask error.
|
|
func (stm *TaskManager) CollectSubTaskError(taskID int64) ([][]byte, error) {
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select error from mysql.tidb_background_subtask where task_key = %? AND state = %?", taskID, proto.TaskStateFailed)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
subTaskErrors := make([][]byte, 0, len(rs))
|
|
for _, err := range rs {
|
|
subTaskErrors = append(subTaskErrors, err.GetBytes(0))
|
|
}
|
|
|
|
return subTaskErrors, nil
|
|
}
|
|
|
|
// HasSubtasksInStates checks if there are subtasks in the states.
|
|
func (stm *TaskManager) HasSubtasksInStates(tidbID string, taskID int64, states ...interface{}) (bool, error) {
|
|
args := []interface{}{tidbID, taskID}
|
|
args = append(args, states...)
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return len(rs) > 0, nil
|
|
}
|
|
|
|
// UpdateSubtaskStateAndError updates the subtask state.
|
|
func (stm *TaskManager) UpdateSubtaskStateAndError(id int64, state string, subTaskErr string) error {
|
|
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set state = %?, error = %? where id = %?", state, subTaskErr, id)
|
|
return err
|
|
}
|
|
|
|
// FinishSubtask updates the subtask meta and mark state to succeed.
|
|
func (stm *TaskManager) FinishSubtask(id int64, meta []byte) error {
|
|
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set meta = %?, state = %? where id = %?", meta, proto.TaskStateSucceed, id)
|
|
return err
|
|
}
|
|
|
|
// UpdateSubtaskHeartbeat updates the heartbeat of the subtask.
|
|
func (stm *TaskManager) UpdateSubtaskHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error {
|
|
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set exec_expired = %? where exec_id = %? and task_key = %?", heartbeat.String(), instanceID, taskID)
|
|
return err
|
|
}
|
|
|
|
// DeleteSubtasksByTaskID deletes the subtask of the given global task ID.
|
|
func (stm *TaskManager) DeleteSubtasksByTaskID(taskID int64) error {
|
|
_, err := stm.executeSQLWithNewSession(stm.ctx, "delete from mysql.tidb_background_subtask where task_key = %?", taskID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetSchedulerIDsByTaskID gets the scheduler IDs of the given global task ID.
|
|
func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) {
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select distinct(exec_id) from mysql.tidb_background_subtask where task_key = %?", taskID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(rs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
instanceIDs := make([]string, 0, len(rs))
|
|
for _, r := range rs {
|
|
id := r.GetString(0)
|
|
instanceIDs = append(instanceIDs, id)
|
|
}
|
|
|
|
return instanceIDs, nil
|
|
}
|
|
|
|
// UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks
|
|
func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, isSubtaskRevert bool) error {
|
|
return stm.withNewTxn(func(se sessionctx.Context) error {
|
|
_, err := execSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %?",
|
|
gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, gTask.Error, gTask.ID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) {
|
|
if val.(bool) {
|
|
failpoint.Return(errors.New("updateTaskErr"))
|
|
}
|
|
})
|
|
|
|
subtaskState := proto.TaskStatePending
|
|
if isSubtaskRevert {
|
|
subtaskState = proto.TaskStateRevertPending
|
|
}
|
|
|
|
for _, subtask := range subtasks {
|
|
// TODO: insert subtasks in batch
|
|
_, err = execSQL(stm.ctx, se, "insert into mysql.tidb_background_subtask(step, task_key, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)",
|
|
gTask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
}
|
|
|
|
// CancelGlobalTask cancels global task
|
|
func (stm *TaskManager) CancelGlobalTask(taskID int64) error {
|
|
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_global_task set state=%? where id=%? and state in (%?, %?)",
|
|
proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// IsGlobalTaskCancelling checks whether the task state is cancelling
|
|
func (stm *TaskManager) IsGlobalTaskCancelling(taskID int64) (bool, error) {
|
|
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?",
|
|
taskID, proto.TaskStateCancelling,
|
|
)
|
|
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
|
|
return len(rs) > 0, nil
|
|
}
|