Files
tidb/disttask/framework/storage/task_table.go

447 lines
14 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/ngaut/pools"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)
// TaskManager is the manager of global/sub task.
type TaskManager struct {
ctx context.Context
sePool *pools.ResourcePool
}
var taskManagerInstance atomic.Pointer[TaskManager]
var (
// TestLastTaskID is used for test to set the last task ID.
TestLastTaskID atomic.Int64
)
// NewTaskManager creates a new task manager.
func NewTaskManager(ctx context.Context, sePool *pools.ResourcePool) *TaskManager {
ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask)
return &TaskManager{
ctx: ctx,
sePool: sePool,
}
}
// GetTaskManager gets the task manager.
func GetTaskManager() (*TaskManager, error) {
v := taskManagerInstance.Load()
if v == nil {
return nil, errors.New("global task manager is not initialized")
}
return v, nil
}
// SetTaskManager sets the task manager.
func SetTaskManager(is *TaskManager) {
taskManagerInstance.Store(is)
}
// execSQL executes the sql and returns the result.
// TODO: consider retry.
func execSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) {
rs, err := se.(sqlexec.SQLExecutor).ExecuteInternal(ctx, sql, args...)
if err != nil {
return nil, err
}
if rs != nil {
defer terror.Call(rs.Close)
return sqlexec.DrainRecordSet(ctx, rs, 1024)
}
return nil, nil
}
// row2GlobeTask converts a row to a global task.
func row2GlobeTask(r chunk.Row) *proto.Task {
task := &proto.Task{
ID: r.GetInt64(0),
Key: r.GetString(1),
Type: r.GetString(2),
DispatcherID: r.GetString(3),
State: r.GetString(4),
Meta: r.GetBytes(7),
Concurrency: uint64(r.GetInt64(8)),
Step: r.GetInt64(9),
Error: r.GetBytes(10),
}
// TODO: convert to local time.
task.StartTime, _ = r.GetTime(5).GoTime(time.UTC)
task.StateUpdateTime, _ = r.GetTime(6).GoTime(time.UTC)
return task
}
// WithNewSession executes the function with a new session.
func (stm *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error {
se, err := stm.sePool.Get()
if err != nil {
return err
}
defer stm.sePool.Put(se)
return fn(se.(sessionctx.Context))
}
func (stm *TaskManager) withNewTxn(fn func(se sessionctx.Context) error) error {
return stm.WithNewSession(func(se sessionctx.Context) (err error) {
_, err = execSQL(stm.ctx, se, "begin")
if err != nil {
return err
}
success := false
defer func() {
sql := "rollback"
if success {
sql = "commit"
}
_, commitErr := execSQL(stm.ctx, se, sql)
if err == nil && commitErr != nil {
err = commitErr
}
}()
if err = fn(se); err != nil {
return err
}
success = true
return nil
})
}
func (stm *TaskManager) executeSQLWithNewSession(ctx context.Context, sql string, args ...interface{}) (rs []chunk.Row, err error) {
err = stm.WithNewSession(func(se sessionctx.Context) error {
rs, err = execSQL(ctx, se, sql, args...)
return err
})
if err != nil {
return nil, err
}
return
}
// AddNewGlobalTask adds a new task to global task table.
func (stm *TaskManager) AddNewGlobalTask(key, tp string, concurrency int, meta []byte) (taskID int64, err error) {
err = stm.WithNewSession(func(se sessionctx.Context) error {
_, err = execSQL(stm.ctx, se, "insert into mysql.tidb_global_task(task_key, type, state, concurrency, step, meta, state_update_time) values (%?, %?, %?, %?, %?, %?, %?)", key, tp, proto.TaskStatePending, concurrency, proto.StepInit, meta, time.Now().UTC().String())
if err != nil {
return err
}
rs, err := execSQL(stm.ctx, se, "select @@last_insert_id")
if err != nil {
return err
}
taskID, err = strconv.ParseInt(rs[0].GetString(0), 10, 64)
if err != nil {
return err
}
failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) })
return nil
})
if err != nil {
return 0, err
}
return
}
// GetNewGlobalTask get a new task from global task table, it's used by dispatcher only.
func (stm *TaskManager) GetNewGlobalTask() (task *proto.Task, err error) {
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending)
if err != nil {
return task, err
}
if len(rs) == 0 {
return nil, nil
}
return row2GlobeTask(rs[0]), nil
}
// GetGlobalTasksInStates gets the tasks in the states.
func (stm *TaskManager) GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error) {
if len(states) == 0 {
return task, nil
}
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...)
if err != nil {
return task, err
}
for _, r := range rs {
task = append(task, row2GlobeTask(r))
}
return task, nil
}
// GetGlobalTaskByID gets the task by the global task ID.
func (stm *TaskManager) GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) {
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where id = %?", taskID)
if err != nil {
return task, err
}
if len(rs) == 0 {
return nil, nil
}
return row2GlobeTask(rs[0]), nil
}
// GetGlobalTaskByKey gets the task by the task key
func (stm *TaskManager) GetGlobalTaskByKey(key string) (task *proto.Task, err error) {
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where task_key = %?", key)
if err != nil {
return task, err
}
if len(rs) == 0 {
return nil, nil
}
return row2GlobeTask(rs[0]), nil
}
// row2SubTask converts a row to a subtask.
func row2SubTask(r chunk.Row) *proto.Subtask {
task := &proto.Subtask{
ID: r.GetInt64(0),
Step: r.GetInt64(1),
Type: proto.Int2Type(int(r.GetInt64(5))),
SchedulerID: r.GetString(6),
State: r.GetString(8),
Meta: r.GetBytes(12),
StartTime: r.GetUint64(10),
}
tid, err := strconv.Atoi(r.GetString(3))
if err != nil {
logutil.BgLogger().Warn("unexpected task ID", zap.String("task ID", r.GetString(3)))
}
task.TaskID = int64(tid)
return task
}
// AddNewSubTask adds a new task to subtask table.
func (stm *TaskManager) AddNewSubTask(globalTaskID int64, step int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error {
st := proto.TaskStatePending
if isRevert {
st = proto.TaskStateRevertPending
}
_, err := stm.executeSQLWithNewSession(stm.ctx, "insert into mysql.tidb_background_subtask(task_key, step, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)", globalTaskID, step, designatedTiDBID, meta, st, proto.Type2Int(tp), []byte{})
if err != nil {
return err
}
return nil
}
// GetSubtaskInStates gets the subtask in the states.
func (stm *TaskManager) GetSubtaskInStates(tidbID string, taskID int64, states ...interface{}) (*proto.Subtask, error) {
args := []interface{}{tidbID, taskID}
args = append(args, states...)
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
if err != nil {
return nil, err
}
if len(rs) == 0 {
return nil, nil
}
return row2SubTask(rs[0]), nil
}
// GetSucceedSubtasksByStep gets the subtask in the success state.
func (stm *TaskManager) GetSucceedSubtasksByStep(taskID int64, step int64) ([]*proto.Subtask, error) {
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where task_key = %? and state = %? and step = %?", taskID, proto.TaskStateSucceed, step)
if err != nil {
return nil, err
}
if len(rs) == 0 {
return nil, nil
}
subtasks := make([]*proto.Subtask, 0, len(rs))
for _, r := range rs {
subtasks = append(subtasks, row2SubTask(r))
}
return subtasks, nil
}
// GetSubtaskInStatesCnt gets the subtask count in the states.
func (stm *TaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) {
args := []interface{}{taskID}
args = append(args, states...)
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select count(*) from mysql.tidb_background_subtask where task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
if err != nil {
return 0, err
}
return rs[0].GetInt64(0), nil
}
// CollectSubTaskError collects the subtask error.
func (stm *TaskManager) CollectSubTaskError(taskID int64) ([][]byte, error) {
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select error from mysql.tidb_background_subtask where task_key = %? AND state = %?", taskID, proto.TaskStateFailed)
if err != nil {
return nil, err
}
subTaskErrors := make([][]byte, 0, len(rs))
for _, err := range rs {
subTaskErrors = append(subTaskErrors, err.GetBytes(0))
}
return subTaskErrors, nil
}
// HasSubtasksInStates checks if there are subtasks in the states.
func (stm *TaskManager) HasSubtasksInStates(tidbID string, taskID int64, states ...interface{}) (bool, error) {
args := []interface{}{tidbID, taskID}
args = append(args, states...)
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_background_subtask where exec_id = %? and task_key = %? and state in ("+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...)
if err != nil {
return false, err
}
return len(rs) > 0, nil
}
// UpdateSubtaskStateAndError updates the subtask state.
func (stm *TaskManager) UpdateSubtaskStateAndError(id int64, state string, subTaskErr string) error {
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set state = %?, error = %? where id = %?", state, subTaskErr, id)
return err
}
// FinishSubtask updates the subtask meta and mark state to succeed.
func (stm *TaskManager) FinishSubtask(id int64, meta []byte) error {
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set meta = %?, state = %? where id = %?", meta, proto.TaskStateSucceed, id)
return err
}
// UpdateSubtaskHeartbeat updates the heartbeat of the subtask.
func (stm *TaskManager) UpdateSubtaskHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error {
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_background_subtask set exec_expired = %? where exec_id = %? and task_key = %?", heartbeat.String(), instanceID, taskID)
return err
}
// DeleteSubtasksByTaskID deletes the subtask of the given global task ID.
func (stm *TaskManager) DeleteSubtasksByTaskID(taskID int64) error {
_, err := stm.executeSQLWithNewSession(stm.ctx, "delete from mysql.tidb_background_subtask where task_key = %?", taskID)
if err != nil {
return err
}
return nil
}
// GetSchedulerIDsByTaskID gets the scheduler IDs of the given global task ID.
func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) {
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select distinct(exec_id) from mysql.tidb_background_subtask where task_key = %?", taskID)
if err != nil {
return nil, err
}
if len(rs) == 0 {
return nil, nil
}
instanceIDs := make([]string, 0, len(rs))
for _, r := range rs {
id := r.GetString(0)
instanceIDs = append(instanceIDs, id)
}
return instanceIDs, nil
}
// UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks
func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, isSubtaskRevert bool) error {
return stm.withNewTxn(func(se sessionctx.Context) error {
_, err := execSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %?",
gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, gTask.Error, gTask.ID)
if err != nil {
return err
}
failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) {
if val.(bool) {
failpoint.Return(errors.New("updateTaskErr"))
}
})
subtaskState := proto.TaskStatePending
if isSubtaskRevert {
subtaskState = proto.TaskStateRevertPending
}
for _, subtask := range subtasks {
// TODO: insert subtasks in batch
_, err = execSQL(stm.ctx, se, "insert into mysql.tidb_background_subtask(step, task_key, exec_id, meta, state, type, checkpoint) values (%?, %?, %?, %?, %?, %?, %?)",
gTask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{})
if err != nil {
return err
}
}
return nil
})
}
// CancelGlobalTask cancels global task
func (stm *TaskManager) CancelGlobalTask(taskID int64) error {
_, err := stm.executeSQLWithNewSession(stm.ctx, "update mysql.tidb_global_task set state=%? where id=%? and state in (%?, %?)",
proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning,
)
return err
}
// IsGlobalTaskCancelling checks whether the task state is cancelling
func (stm *TaskManager) IsGlobalTaskCancelling(taskID int64) (bool, error) {
rs, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?",
taskID, proto.TaskStateCancelling,
)
if err != nil {
return false, err
}
return len(rs) > 0, nil
}