Files
tidb/disttask/framework/storage/task_table.go
2023-09-12 01:00:09 +00:00

750 lines
24 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"fmt"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/ngaut/pools"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/terror"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/sqlexec"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)
// SessionExecutor runs callbacks against sessions taken from a pool, either
// directly or wrapped in an explicit transaction.
type SessionExecutor interface {
	// WithNewSession invokes fn with a freshly borrowed session and returns
	// fn's error.
	WithNewSession(fn func(sctx sessionctx.Context) error) error
	// WithNewTxn invokes fn inside a new transaction opened with ctx.
	WithNewTxn(ctx context.Context, fn func(sctx sessionctx.Context) error) error
}
// TaskManager is the manager of global/sub task.
// It issues SQL against mysql.tidb_global_task, mysql.tidb_background_subtask
// and mysql.dist_framework_meta using sessions borrowed from sePool.
type TaskManager struct {
	// ctx is the context used for the manager's own SQL; NewTaskManager tags it
	// with kv.InternalDistTask before storing it.
	ctx context.Context
	// sePool supplies the sessions the statements run on.
	sePool sessionPool
}
// sessionPool is the minimal pool contract TaskManager relies on: Get borrows
// a resource (used as a sessionctx.Context) and Put hands it back.
type sessionPool interface {
	Get() (pools.Resource, error)
	Put(res pools.Resource)
}
var (
	// Compile-time assertion that *TaskManager satisfies SessionExecutor.
	_ SessionExecutor = (*TaskManager)(nil)

	// taskManagerInstance holds the TaskManager registered through
	// SetTaskManager; it is nil until one is registered.
	taskManagerInstance atomic.Pointer[TaskManager]

	// TestLastTaskID is used for test to set the last task ID. It is written by
	// AddGlobalTaskWithSession when the testSetLastTaskID failpoint is enabled.
	TestLastTaskID atomic.Int64
)
// NewTaskManager returns a TaskManager that borrows sessions from sePool. The
// supplied ctx is tagged with kv.InternalDistTask through
// util.WithInternalSourceType before it is stored for later SQL execution.
func NewTaskManager(ctx context.Context, sePool sessionPool) *TaskManager {
	return &TaskManager{
		ctx:    util.WithInternalSourceType(ctx, kv.InternalDistTask),
		sePool: sePool,
	}
}
// GetTaskManager returns the package-level TaskManager registered through
// SetTaskManager, or an error if none has been registered yet.
func GetTaskManager() (*TaskManager, error) {
	if mgr := taskManagerInstance.Load(); mgr != nil {
		return mgr, nil
	}
	return nil, errors.New("global task manager is not initialized")
}
// SetTaskManager registers mgr as the package-level TaskManager returned by
// GetTaskManager, replacing any previously registered one.
func SetTaskManager(mgr *TaskManager) {
	taskManagerInstance.Store(mgr)
}
// ExecSQL runs sql on se as an internal statement, substituting args for the
// %? placeholders, and returns all rows the statement produced. Statements
// that yield no record set return (nil, nil). The record set is always closed.
// TODO: consider retry.
func ExecSQL(ctx context.Context, se sessionctx.Context, sql string, args ...interface{}) ([]chunk.Row, error) {
	executor := se.(sqlexec.SQLExecutor)
	recordSet, err := executor.ExecuteInternal(ctx, sql, args...)
	if err != nil {
		return nil, err
	}
	if recordSet == nil {
		return nil, nil
	}
	defer terror.Call(recordSet.Close)
	return sqlexec.DrainRecordSet(ctx, recordSet, 1024)
}
// row2GlobeTask converts a mysql.tidb_global_task row into a proto.Task. The
// row must be selected in the column order used throughout this file:
// id, task_key, type, dispatcher_id, state, start_time, state_update_time,
// meta, concurrency, step, error.
//
// A non-NULL error column is decoded as a JSON-encoded *errors.Error. If that
// decoding fails, the decoding error is logged and stored in Task.Error instead.
func row2GlobeTask(r chunk.Row) *proto.Task {
	// Column positions, named after the proto.Task field each one feeds.
	const (
		idxID = iota
		idxKey
		idxType
		idxDispatcherID
		idxState
		idxStartTime
		idxStateUpdateTime
		idxMeta
		idxConcurrency
		idxStep
		idxError
	)
	task := &proto.Task{
		ID:           r.GetInt64(idxID),
		Key:          r.GetString(idxKey),
		Type:         r.GetString(idxType),
		DispatcherID: r.GetString(idxDispatcherID),
		State:        r.GetString(idxState),
		Meta:         r.GetBytes(idxMeta),
		Concurrency:  uint64(r.GetInt64(idxConcurrency)),
		Step:         r.GetInt64(idxStep),
	}
	if !r.IsNull(idxError) {
		decoded := errors.Normalize("")
		if decodeErr := decoded.UnmarshalJSON(r.GetBytes(idxError)); decodeErr != nil {
			logutil.BgLogger().Error("unmarshal error", zap.Error(decodeErr))
			task.Error = decodeErr
		} else {
			task.Error = decoded
		}
	}
	// TODO: convert to local time.
	// Conversion errors are discarded; the fields keep whatever GoTime returns.
	task.StartTime, _ = r.GetTime(idxStartTime).GoTime(time.UTC)
	task.StateUpdateTime, _ = r.GetTime(idxStateUpdateTime).GoTime(time.UTC)
	return task
}
// WithNewSession borrows a session from the pool, calls fn with it and returns
// fn's error. The session goes back to the pool when WithNewSession returns.
// An error from the pool itself is returned without calling fn.
func (stm *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error {
	res, err := stm.sePool.Get()
	if err != nil {
		return err
	}
	defer stm.sePool.Put(res)
	sctx := res.(sessionctx.Context)
	return fn(sctx)
}
// WithNewTxn executes the fn in a new transaction.
//
// It tags ctx with kv.InternalDistTask, borrows a pooled session, issues
// BEGIN and then runs fn. A deferred block issues COMMIT when fn returned nil
// and ROLLBACK otherwise, including when fn panics, because success is still
// false at that point. When fn fails, its error is returned and any ROLLBACK
// error is dropped. When fn succeeds, a COMMIT failure is returned.
func (stm *TaskManager) WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error {
	ctx = util.WithInternalSourceType(ctx, kv.InternalDistTask)
	return stm.WithNewSession(func(se sessionctx.Context) (err error) {
		_, err = ExecSQL(ctx, se, "begin")
		if err != nil {
			return err
		}
		// success flips to true only after fn returns nil; it decides
		// commit vs rollback in the deferred block below.
		success := false
		defer func() {
			sql := "rollback"
			if success {
				sql = "commit"
			}
			// err is the closure's named result, so assigning it here changes the
			// returned error; an earlier error from fn is never overwritten.
			_, commitErr := ExecSQL(ctx, se, sql)
			if err == nil && commitErr != nil {
				err = commitErr
			}
		}()
		if err = fn(se); err != nil {
			return err
		}
		success = true
		return nil
	})
}
// executeSQLWithNewSession runs one statement on a pooled session and returns
// its rows. When an error occurs (from the pool or the statement), the
// returned rows are nil.
func (stm *TaskManager) executeSQLWithNewSession(ctx context.Context, sql string, args ...interface{}) (rs []chunk.Row, err error) {
	var rows []chunk.Row
	runErr := stm.WithNewSession(func(se sessionctx.Context) error {
		var execErr error
		rows, execErr = ExecSQL(ctx, se, sql, args...)
		return execErr
	})
	if runErr != nil {
		return nil, runErr
	}
	return rows, nil
}
// AddNewGlobalTask inserts a new task into mysql.tidb_global_task on a
// pooled session (see AddGlobalTaskWithSession) and returns its generated ID.
// The ID is 0 when an error is returned.
func (stm *TaskManager) AddNewGlobalTask(key, tp string, concurrency int, meta []byte) (taskID int64, err error) {
	var id int64
	err = stm.WithNewSession(func(se sessionctx.Context) error {
		newID, addErr := stm.AddGlobalTaskWithSession(se, key, tp, concurrency, meta)
		id = newID
		return addErr
	})
	return id, err
}
// AddGlobalTaskWithSession inserts a task in TaskStatePending at
// proto.StepInit into mysql.tidb_global_task using se. It sets
// state_update_time from the current UTC time, formatted with
// time.Time.String. It then reads @@last_insert_id on the same session and
// returns that value as the new task ID.
func (stm *TaskManager) AddGlobalTaskWithSession(se sessionctx.Context, key, tp string, concurrency int, meta []byte) (taskID int64, err error) {
	insertArgs := []interface{}{key, tp, proto.TaskStatePending, concurrency, proto.StepInit, meta, time.Now().UTC().String()}
	if _, err = ExecSQL(stm.ctx, se,
		`insert into mysql.tidb_global_task(task_key, type, state, concurrency, step, meta, state_update_time)
values (%?, %?, %?, %?, %?, %?, %?)`,
		insertArgs...); err != nil {
		return 0, err
	}
	rows, err := ExecSQL(stm.ctx, se, "select @@last_insert_id")
	if err != nil {
		return 0, err
	}
	taskID = int64(rows[0].GetUint64(0))
	failpoint.Inject("testSetLastTaskID", func() { TestLastTaskID.Store(taskID) })
	return taskID, nil
}
// GetNewGlobalTask returns one task in TaskStatePending from
// mysql.tidb_global_task, or (nil, nil) when there is none. The query has no
// ORDER BY, so which pending task is returned is unspecified. It is used by
// the dispatcher only.
func (stm *TaskManager) GetNewGlobalTask() (task *proto.Task, err error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending)
	switch {
	case err != nil:
		return nil, err
	case len(rows) == 0:
		return nil, nil
	default:
		return row2GlobeTask(rows[0]), nil
	}
}
// GetGlobalTasksInStates returns every global task whose state is one of
// states. It returns nil without querying when no states are given, and nil
// when nothing matches.
func (stm *TaskManager) GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error) {
	if len(states) == 0 {
		return nil, nil
	}
	// One %? per state, comma separated.
	placeholders := strings.Repeat("%?,", len(states)-1) + "%?"
	rows, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where state in ("+placeholders+")", states...)
	if err != nil {
		return nil, err
	}
	for i := range rows {
		task = append(task, row2GlobeTask(rows[i]))
	}
	return task, nil
}
// GetGlobalTaskByID returns the global task whose id is taskID, or (nil, nil)
// when no such task exists.
func (stm *TaskManager) GetGlobalTaskByID(taskID int64) (task *proto.Task, err error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where id = %?", taskID)
	if err != nil || len(rows) == 0 {
		return nil, err
	}
	return row2GlobeTask(rows[0]), nil
}
// GetGlobalTaskByKey returns the global task whose task_key equals key, or
// (nil, nil) when no such task exists. If several rows match, the first row
// returned by the query is used.
func (stm *TaskManager) GetGlobalTaskByKey(key string) (task *proto.Task, err error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, "select id, task_key, type, dispatcher_id, state, start_time, state_update_time, meta, concurrency, step, error from mysql.tidb_global_task where task_key = %?", key)
	if err != nil || len(rows) == 0 {
		return nil, err
	}
	return row2GlobeTask(rows[0]), nil
}
// row2SubTask converts a row read with `select *` from
// mysql.tidb_background_subtask into a proto.Subtask.
//
// The subtask table stores start/update time as BIGINT Unix seconds, kept that
// way for backward compatibility; they are converted to time.Time here, and a
// NULL leaves the zero time. The column at index 3 holds the global task ID as
// a string. It is parsed as a 64-bit integer; a value that fails to parse is
// logged with the parse error, and TaskID gets whatever strconv.ParseInt
// returned (0 for malformed input).
func row2SubTask(r chunk.Row) *proto.Subtask {
	// Column positions, named after the proto.Subtask field each one feeds.
	const (
		idxID          = 0
		idxStep        = 1
		idxTaskID      = 3
		idxType        = 5
		idxSchedulerID = 6
		idxState       = 8
		idxStartTime   = 10
		idxUpdateTime  = 11
		idxMeta        = 12
		idxSummary     = 14
	)
	var startTime, updateTime time.Time
	if !r.IsNull(idxStartTime) {
		startTime = time.Unix(r.GetInt64(idxStartTime), 0)
	}
	if !r.IsNull(idxUpdateTime) {
		updateTime = time.Unix(r.GetInt64(idxUpdateTime), 0)
	}
	task := &proto.Subtask{
		ID:          r.GetInt64(idxID),
		Step:        r.GetInt64(idxStep),
		Type:        proto.Int2Type(int(r.GetInt64(idxType))),
		SchedulerID: r.GetString(idxSchedulerID),
		State:       r.GetString(idxState),
		Meta:        r.GetBytes(idxMeta),
		Summary:     r.GetString(idxSummary),
		StartTime:   startTime,
		UpdateTime:  updateTime,
	}
	rawTaskID := r.GetString(idxTaskID)
	// ParseInt with an explicit 64-bit size: strconv.Atoi is bounded by the
	// platform int and would reject task IDs above 2^31-1 on 32-bit builds.
	tid, err := strconv.ParseInt(rawTaskID, 10, 64)
	if err != nil {
		logutil.BgLogger().Warn("unexpected task ID", zap.String("task ID", rawTaskID), zap.Error(err))
	}
	task.TaskID = tid
	return task
}
// AddNewSubTask inserts one subtask of global task globalTaskID at the given
// step, assigned to designatedTiDBID (stored in exec_id). The subtask starts
// in TaskStateRevertPending when isRevert is set and TaskStatePending
// otherwise. It gets an empty checkpoint and a "{}" summary.
func (stm *TaskManager) AddNewSubTask(globalTaskID int64, step int64, designatedTiDBID string, meta []byte, tp string, isRevert bool) error {
	initialState := proto.TaskStatePending
	if isRevert {
		initialState = proto.TaskStateRevertPending
	}
	const insertStmt = `insert into mysql.tidb_background_subtask
(task_key, step, exec_id, meta, state, type, checkpoint, summary)
values (%?, %?, %?, %?, %?, %?, %?, %?)`
	_, err := stm.executeSQLWithNewSession(stm.ctx, insertStmt,
		globalTaskID, step, designatedTiDBID, meta, initialState, proto.Type2Int(tp), []byte{}, "{}")
	return err
}
// GetSubtaskInStates returns a subtask of global task taskID at the given step
// that is assigned to tidbID and whose state is one of states. It returns
// (nil, nil) when none matches. The query has no ORDER BY, so if several rows
// match, which one is returned is unspecified.
//
// Calling it with no states returns (nil, nil) without querying. An empty IN
// list cannot match, and building the placeholder list would otherwise panic.
func (stm *TaskManager) GetSubtaskInStates(tidbID string, taskID int64, step int64, states ...interface{}) (*proto.Subtask, error) {
	// strings.Repeat panics on a negative count (len(states)-1 == -1).
	if len(states) == 0 {
		return nil, nil
	}
	args := make([]interface{}, 0, 3+len(states))
	args = append(args, tidbID, taskID, step)
	args = append(args, states...)
	rs, err := stm.executeSQLWithNewSession(stm.ctx, `select * from mysql.tidb_background_subtask
where exec_id = %? and task_key = %? and step = %?
and state in (`+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
	if err != nil {
		return nil, err
	}
	if len(rs) == 0 {
		return nil, nil
	}
	return row2SubTask(rs[0]), nil
}
// UpdateErrorToSubtask marks at most one pending subtask of taskID that is
// assigned to tidbID as failed. It stores err serialized by serializeErr and
// stamps start_time and state_update_time with the current Unix time. A nil
// err is a no-op.
func (stm *TaskManager) UpdateErrorToSubtask(tidbID string, taskID int64, err error) error {
	if err != nil {
		_, updateErr := stm.executeSQLWithNewSession(stm.ctx, `update mysql.tidb_background_subtask
set state = %?, error = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp()
where exec_id = %? and task_key = %? and state = %? limit 1;`,
			proto.TaskStateFailed, serializeErr(err), tidbID, taskID, proto.TaskStatePending)
		return updateErr
	}
	return nil
}
// PrintSubtaskInfo logs every subtask of global task taskKey at INFO level. It
// attaches the subtask's decoded error column (index 13 of `select *`) when
// that column is non-empty; if decoding fails, the decode error is attached
// instead. This is diagnostics only: a failed query is logged as a warning
// rather than silently ignored, and nothing is returned to the caller.
func (stm *TaskManager) PrintSubtaskInfo(taskKey int64) {
	rs, err := stm.executeSQLWithNewSession(stm.ctx,
		"select * from mysql.tidb_background_subtask where task_key = %?", taskKey)
	if err != nil {
		logutil.BgLogger().Warn("query subtask info failed", zap.Int64("task-key", taskKey), zap.Error(err))
		return
	}
	for _, r := range rs {
		errBytes := r.GetBytes(13)
		var subtaskErr error
		if len(errBytes) > 0 {
			stdErr := errors.Normalize("")
			if decodeErr := stdErr.UnmarshalJSON(errBytes); decodeErr != nil {
				subtaskErr = decodeErr
			} else {
				subtaskErr = stdErr
			}
		}
		logutil.BgLogger().Info(fmt.Sprintf("subTask: %v\n", row2SubTask(r)), zap.Error(subtaskErr))
	}
}
// GetSucceedSubtasksByStep returns the subtasks of taskID at the given step
// that are in TaskStateSucceed. It returns nil when there are none.
func (stm *TaskManager) GetSucceedSubtasksByStep(taskID int64, step int64) ([]*proto.Subtask, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, `select * from mysql.tidb_background_subtask
where task_key = %? and state = %? and step = %?`,
		taskID, proto.TaskStateSucceed, step)
	if err != nil || len(rows) == 0 {
		return nil, err
	}
	subtasks := make([]*proto.Subtask, len(rows))
	for i := range rows {
		subtasks[i] = row2SubTask(rows[i])
	}
	return subtasks, nil
}
// GetSubtaskRowCount returns the sum of the summary's row_count field over all
// subtasks of taskID at the given step, cast to a signed integer. It returns 0
// when there are no such subtasks or none has a row_count recorded.
func (stm *TaskManager) GetSubtaskRowCount(taskID int64, step int64) (int64, error) {
	rs, err := stm.executeSQLWithNewSession(stm.ctx, `select
cast(sum(json_extract(summary, '$.row_count')) as signed) as row_count
from mysql.tidb_background_subtask where task_key = %? and step = %?`,
		taskID, step)
	if err != nil {
		return 0, err
	}
	// SUM over zero rows, or over only NULL values, yields a single NULL row;
	// its int64 slot is not a meaningful value, so report 0 explicitly.
	if len(rs) == 0 || rs[0].IsNull(0) {
		return 0, nil
	}
	return rs[0].GetInt64(0), nil
}
// UpdateSubtaskRowCount writes rowCount into the $.row_count path of the JSON
// summary of subtask subtaskID using json_set.
func (stm *TaskManager) UpdateSubtaskRowCount(subtaskID int64, rowCount int64) error {
	const updateStmt = `update mysql.tidb_background_subtask
set summary = json_set(summary, '$.row_count', %?) where id = %?`
	_, err := stm.executeSQLWithNewSession(stm.ctx, updateStmt, rowCount, subtaskID)
	return err
}
// GetSubtaskInStatesCnt returns how many subtasks of taskID have a state in
// states. With no states it returns 0 without querying. An empty IN list
// matches nothing, and building the placeholder list would otherwise panic.
func (stm *TaskManager) GetSubtaskInStatesCnt(taskID int64, states ...interface{}) (int64, error) {
	// strings.Repeat panics on a negative count (len(states)-1 == -1).
	if len(states) == 0 {
		return 0, nil
	}
	args := make([]interface{}, 0, 1+len(states))
	args = append(args, taskID)
	args = append(args, states...)
	rs, err := stm.executeSQLWithNewSession(stm.ctx, `select count(*) from mysql.tidb_background_subtask
where task_key = %? and state in (`+strings.Repeat("%?,", len(states)-1)+"%?)", args...)
	if err != nil {
		return 0, err
	}
	// count(*) without GROUP BY yields exactly one row; guard anyway instead of
	// risking an index-out-of-range panic.
	if len(rs) == 0 {
		return 0, nil
	}
	return rs[0].GetInt64(0), nil
}
// CollectSubTaskError returns one entry per failed subtask of taskID, holding
// that subtask's decoded error. A NULL or empty error column yields a nil
// entry. If any stored error cannot be decoded, the whole call fails with the
// decoding error.
func (stm *TaskManager) CollectSubTaskError(taskID int64) ([]error, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, `select error from mysql.tidb_background_subtask
where task_key = %? AND state = %?`, taskID, proto.TaskStateFailed)
	if err != nil {
		return nil, err
	}
	collected := make([]error, 0, len(rows))
	for _, row := range rows {
		var subtaskErr error
		if !row.IsNull(0) {
			if raw := row.GetBytes(0); len(raw) > 0 {
				decoded := errors.Normalize("")
				if decodeErr := decoded.UnmarshalJSON(raw); decodeErr != nil {
					return nil, decodeErr
				}
				subtaskErr = decoded
			}
		}
		collected = append(collected, subtaskErr)
	}
	return collected, nil
}
// HasSubtasksInStates reports whether any subtask of taskID at the given step,
// assigned to tidbID, has a state in states. With no states it returns false
// without querying. An empty IN list matches nothing, and building the
// placeholder list would otherwise panic.
func (stm *TaskManager) HasSubtasksInStates(tidbID string, taskID int64, step int64, states ...interface{}) (bool, error) {
	// strings.Repeat panics on a negative count (len(states)-1 == -1).
	if len(states) == 0 {
		return false, nil
	}
	args := make([]interface{}, 0, 3+len(states))
	args = append(args, tidbID, taskID, step)
	args = append(args, states...)
	rs, err := stm.executeSQLWithNewSession(stm.ctx, `select 1 from mysql.tidb_background_subtask
where exec_id = %? and task_key = %? and step = %?
and state in (`+strings.Repeat("%?,", len(states)-1)+"%?) limit 1", args...)
	if err != nil {
		return false, err
	}
	return len(rs) > 0, nil
}
// StartSubtask moves subtask subtaskID to TaskStateRunning and stamps both
// start_time and state_update_time with the current Unix timestamp.
func (stm *TaskManager) StartSubtask(subtaskID int64) error {
	const updateStmt = `update mysql.tidb_background_subtask
set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp()
where id = %?`
	_, err := stm.executeSQLWithNewSession(stm.ctx, updateStmt, proto.TaskStateRunning, subtaskID)
	return err
}
// StartManager records tidbID with the given role in
// mysql.dist_framework_meta, leaving the third column at its DEFAULT. On a
// duplicate key (presumably the same node ID — confirm against the schema),
// only the role is updated.
func (stm *TaskManager) StartManager(tidbID string, role string) error {
	const upsertStmt = `insert into mysql.dist_framework_meta values(%?, %?, DEFAULT)
on duplicate key update role = %?`
	_, err := stm.executeSQLWithNewSession(stm.ctx, upsertStmt, tidbID, role, role)
	return err
}
// UpdateSubtaskStateAndError sets the state of subtask id and overwrites its
// error column with serializeErr(subTaskErr), which is nil for a nil error. It
// also refreshes state_update_time.
func (stm *TaskManager) UpdateSubtaskStateAndError(id int64, state string, subTaskErr error) error {
	encodedErr := serializeErr(subTaskErr)
	_, err := stm.executeSQLWithNewSession(stm.ctx, `update mysql.tidb_background_subtask
set state = %?, error = %?, state_update_time = unix_timestamp() where id = %?`,
		state, encodedErr, id)
	return err
}
// FinishSubtask replaces the meta of subtask id, marks it TaskStateSucceed and
// refreshes state_update_time.
func (stm *TaskManager) FinishSubtask(id int64, meta []byte) error {
	const finishStmt = `update mysql.tidb_background_subtask
set meta = %?, state = %?, state_update_time = unix_timestamp() where id = %?`
	_, err := stm.executeSQLWithNewSession(stm.ctx, finishStmt, meta, proto.TaskStateSucceed, id)
	return err
}
// UpdateSubtaskHeartbeat stores heartbeat, formatted with time.Time.String,
// in exec_expired for every subtask of taskID assigned to instanceID.
// not used now.
// TODO: not sure whether we really need this method, don't update state_update_time now,
func (stm *TaskManager) UpdateSubtaskHeartbeat(instanceID string, taskID int64, heartbeat time.Time) error {
	stamp := heartbeat.String()
	_, err := stm.executeSQLWithNewSession(stm.ctx, `update mysql.tidb_background_subtask
set exec_expired = %? where exec_id = %? and task_key = %?`,
		stamp, instanceID, taskID)
	return err
}
// DeleteSubtasksByTaskID removes every subtask belonging to global task
// taskID from mysql.tidb_background_subtask.
func (stm *TaskManager) DeleteSubtasksByTaskID(taskID int64) error {
	const deleteStmt = `delete from mysql.tidb_background_subtask
where task_key = %?`
	_, err := stm.executeSQLWithNewSession(stm.ctx, deleteStmt, taskID)
	return err
}
// GetSchedulerIDsByTaskID returns the distinct exec_id values (scheduler IDs)
// of the subtasks of global task taskID, or nil when it has no subtasks.
func (stm *TaskManager) GetSchedulerIDsByTaskID(taskID int64) ([]string, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, `select distinct(exec_id) from mysql.tidb_background_subtask
where task_key = %?`, taskID)
	if err != nil || len(rows) == 0 {
		return nil, err
	}
	ids := make([]string, len(rows))
	for i, row := range rows {
		ids[i] = row.GetString(0)
	}
	return ids, nil
}
// GetSchedulerIDsByTaskIDAndStep returns the distinct exec_id values
// (scheduler IDs) of the subtasks of global task taskID at the given step, or
// nil when there are none.
func (stm *TaskManager) GetSchedulerIDsByTaskIDAndStep(taskID int64, step int64) ([]string, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, `select distinct(exec_id) from mysql.tidb_background_subtask
where task_key = %? and step = %?`, taskID, step)
	if err != nil || len(rows) == 0 {
		return nil, err
	}
	ids := make([]string, len(rows))
	for i, row := range rows {
		ids[i] = row.GetString(0)
	}
	return ids, nil
}
// IsSchedulerCanceled checks if subtask 'execID' of task 'taskID' has been
// canceled somehow. It reports true when execID no longer owns any subtask of
// taskID.
func (stm *TaskManager) IsSchedulerCanceled(taskID int64, execID string) (bool, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_background_subtask where task_key = %? and exec_id = %?", taskID, execID)
	if err != nil {
		return false, err
	}
	stillAssigned := len(rows) > 0
	return !stillAssigned, nil
}
// UpdateFailedSchedulerIDs replaces failed scheduler nodes with alive nodes.
//
// replaceNodes maps the exec_id of each failed node to the exec_id of the node
// taking over. Every subtask of taskID that is assigned to a failed node and
// is not in TaskStateSucceed is reassigned to the replacement node and reset
// to TaskStatePending, all in one UPDATE statement. An empty map is a no-op.
func (stm *TaskManager) UpdateFailedSchedulerIDs(taskID int64, replaceNodes map[string]string) error {
	if len(replaceNodes) == 0 {
		return nil
	}
	sql := new(strings.Builder)
	if err := sqlexec.FormatSQL(sql, "update mysql.tidb_background_subtask set state = %? ,exec_id = (case ", proto.TaskStatePending); err != nil {
		return err
	}
	// Collect the failed IDs while emitting the CASE arms so the IN list below
	// is built from the same set without a second pass over the map.
	failedIDs := make([]string, 0, len(replaceNodes))
	for failedID, aliveID := range replaceNodes {
		if err := sqlexec.FormatSQL(sql, "when exec_id = %? then %? ", failedID, aliveID); err != nil {
			return err
		}
		failedIDs = append(failedIDs, failedID)
	}
	// Bind the succeed state as a parameter rather than a double-quoted
	// literal. Under the ANSI_QUOTES sql_mode, "succeed" would be parsed as an
	// identifier. Binding it also keeps the value tied to the proto constant.
	if err := sqlexec.FormatSQL(sql, " end) where task_key = %? and state != %? and exec_id in (", taskID, proto.TaskStateSucceed); err != nil {
		return err
	}
	for i, failedID := range failedIDs {
		if i > 0 {
			if err := sqlexec.FormatSQL(sql, ","); err != nil {
				return err
			}
		}
		if err := sqlexec.FormatSQL(sql, "%?", failedID); err != nil {
			return err
		}
	}
	if err := sqlexec.FormatSQL(sql, ")"); err != nil {
		return err
	}
	_, err := stm.executeSQLWithNewSession(stm.ctx, sql.String())
	return err
}
// AddSubTasks inserts subtasks for task inside a single transaction, so
// either all rows are added or the transaction is rolled back. Each row is
// written at task.Step with task.ID as task_key. Its state is
// TaskStateRevertPending while task is reverting and TaskStatePending
// otherwise, with an empty checkpoint and a "{}" summary.
func (stm *TaskManager) AddSubTasks(task *proto.Task, subtasks []*proto.Subtask) error {
	return stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error {
		for i := range subtasks {
			st := subtasks[i]
			state := proto.TaskStatePending
			if task.State == proto.TaskStateReverting {
				state = proto.TaskStateRevertPending
			}
			if _, err := ExecSQL(stm.ctx, se, `insert into mysql.tidb_background_subtask
(step, task_key, exec_id, meta, state, type, checkpoint, summary)
values (%?, %?, %?, %?, %?, %?, %?, %?)`,
				task.Step, task.ID, st.SchedulerID, st.Meta, state, proto.Type2Int(st.Type), []byte{}, "{}"); err != nil {
				return err
			}
		}
		return nil
	})
}
// UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks
//
// Within one transaction (WithNewTxn) it does two things. First, it overwrites
// state, dispatcher_id, step, state_update_time, concurrency, meta and error
// of gTask's row, but only if the row's state still equals prevState.
// Second, it inserts one subtask per element of subtasks at gTask.Step. New
// subtasks start in TaskStateRevertPending when gTask is reverting and in
// TaskStatePending otherwise. Any error rolls the whole transaction back.
//
// The returned bool says whether a failure may be retried. It is false only
// when the guarded update reports zero affected rows, which is taken to mean
// the task is no longer in prevState.
func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, prevState string) (bool, error) {
	retryable := true
	err := stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error {
		_, err := ExecSQL(stm.ctx, se, "update mysql.tidb_global_task set state = %?, dispatcher_id = %?, step = %?, state_update_time = %?, concurrency = %?, meta = %?, error = %? where id = %? and state = %?",
			gTask.State, gTask.DispatcherID, gTask.Step, gTask.StateUpdateTime.UTC().String(), gTask.Concurrency, gTask.Meta, serializeErr(gTask.Error), gTask.ID, prevState)
		if err != nil {
			return err
		}
		// No affected row: the state-guarded update did not apply, so retrying
		// with the same prevState is pointless.
		if se.GetSessionVars().StmtCtx.AffectedRows() == 0 {
			retryable = false
			return errors.New("invalid task state transform, state already changed")
		}
		// When enabled with true, injects an error after the update succeeds,
		// forcing the transaction to roll back.
		failpoint.Inject("MockUpdateTaskErr", func(val failpoint.Value) {
			if val.(bool) {
				failpoint.Return(errors.New("updateTaskErr"))
			}
		})
		subtaskState := proto.TaskStatePending
		if gTask.State == proto.TaskStateReverting {
			subtaskState = proto.TaskStateRevertPending
		}
		for _, subtask := range subtasks {
			// TODO: insert subtasks in batch
			_, err = ExecSQL(stm.ctx, se, `insert into mysql.tidb_background_subtask
(step, task_key, exec_id, meta, state, type, checkpoint, summary)
values (%?, %?, %?, %?, %?, %?, %?, %?)`,
				gTask.Step, gTask.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}, "{}")
			if err != nil {
				return err
			}
		}
		return nil
	})
	return retryable, err
}
// serializeErr JSON-encodes err for storage in an error column. It works on
// the root cause (errors.Cause). If that cause is not already an
// *errors.Error, it is wrapped with errors.Normalize using its message. It
// returns nil for a nil err and when encoding fails.
func serializeErr(err error) []byte {
	if err == nil {
		return nil
	}
	var tErr *errors.Error
	switch cause := errors.Cause(err).(type) {
	case *errors.Error:
		tErr = cause
	default:
		tErr = errors.Normalize(cause.Error())
	}
	encoded, marshalErr := tErr.MarshalJSON()
	if marshalErr != nil {
		return nil
	}
	return encoded
}
// CancelGlobalTask moves global task taskID to TaskStateCancelling if it is
// currently pending or running. A task in any other state is left unchanged,
// and nil is still returned.
func (stm *TaskManager) CancelGlobalTask(taskID int64) error {
	const cancelStmt = "update mysql.tidb_global_task set state=%? where id=%? and state in (%?, %?)"
	_, err := stm.executeSQLWithNewSession(stm.ctx, cancelStmt,
		proto.TaskStateCancelling, taskID, proto.TaskStatePending, proto.TaskStateRunning)
	return err
}
// CancelGlobalTaskByKeySession moves the global task(s) with task_key taskKey
// to TaskStateCancelling if pending or running. The statement runs on the
// caller-provided session se, so it takes part in whatever transaction se
// currently has.
func (stm *TaskManager) CancelGlobalTaskByKeySession(se sessionctx.Context, taskKey string) error {
	const cancelStmt = "update mysql.tidb_global_task set state=%? where task_key=%? and state in (%?, %?)"
	_, err := ExecSQL(stm.ctx, se, cancelStmt,
		proto.TaskStateCancelling, taskKey, proto.TaskStatePending, proto.TaskStateRunning)
	return err
}
// IsGlobalTaskCancelling reports whether global task taskID exists and is in
// TaskStateCancelling.
func (stm *TaskManager) IsGlobalTaskCancelling(taskID int64) (bool, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?",
		taskID, proto.TaskStateCancelling)
	if err != nil {
		return false, err
	}
	cancelling := len(rows) > 0
	return cancelling, nil
}
// GetSubtasksByStep returns all subtasks of global task taskID at the given
// step, regardless of state, or nil when there are none.
func (stm *TaskManager) GetSubtasksByStep(taskID, step int64) ([]*proto.Subtask, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx,
		"select * from mysql.tidb_background_subtask where task_key = %? and step = %?",
		taskID, step)
	if err != nil || len(rows) == 0 {
		return nil, err
	}
	subtasks := make([]*proto.Subtask, len(rows))
	for i := range rows {
		subtasks[i] = row2SubTask(rows[i])
	}
	return subtasks, nil
}
// GetNodesByRole returns the hosts registered in mysql.dist_framework_meta
// with the given role, as a set (every value is true). The map is empty, not
// nil, when no host has that role.
func (stm *TaskManager) GetNodesByRole(role string) (map[string]bool, error) {
	rows, err := stm.executeSQLWithNewSession(stm.ctx,
		"select host from mysql.dist_framework_meta where role = %?", role)
	if err != nil {
		return nil, err
	}
	nodes := make(map[string]bool, len(rows))
	for i := range rows {
		host := rows[i].GetString(0)
		nodes[host] = true
	}
	return nodes, nil
}