disttask: refine taskTable and remove useless code (#50461)

ref pingcap/tidb#48795
This commit is contained in:
EasonBall
2024-01-17 12:06:15 +08:00
committed by GitHub
parent 0c584297fe
commit 595fa7affc
19 changed files with 739 additions and 726 deletions

View File

@ -60,7 +60,7 @@ func TestFrameworkPauseAndResume(t *testing.T) {
mgr, err := storage.GetTaskManager()
require.NoError(t, err)
errs, err := mgr.CollectSubTaskError(ctx, 1)
errs, err := mgr.GetSubtaskErrors(ctx, 1)
require.NoError(t, err)
require.Empty(t, errs)
@ -77,7 +77,7 @@ func TestFrameworkPauseAndResume(t *testing.T) {
CheckSubtasksState(ctx, t, 1, proto.SubtaskStateSucceed, 4)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/syncAfterResume"))
errs, err = mgr.CollectSubTaskError(ctx, 1)
errs, err = mgr.GetSubtaskErrors(ctx, 1)
require.NoError(t, err)
require.Empty(t, errs)
distContext.Close()

View File

@ -251,21 +251,6 @@ func (mr *MockTaskManagerMockRecorder) CancelTask(arg0, arg1 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelTask", reflect.TypeOf((*MockTaskManager)(nil).CancelTask), arg0, arg1)
}
// CollectSubTaskError mocks base method.
func (m *MockTaskManager) CollectSubTaskError(arg0 context.Context, arg1 int64) ([]error, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CollectSubTaskError", arg0, arg1)
ret0, _ := ret[0].([]error)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CollectSubTaskError indicates an expected call of CollectSubTaskError.
func (mr *MockTaskManagerMockRecorder) CollectSubTaskError(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CollectSubTaskError", reflect.TypeOf((*MockTaskManager)(nil).CollectSubTaskError), arg0, arg1)
}
// DeleteDeadNodes mocks base method.
func (m *MockTaskManager) DeleteDeadNodes(arg0 context.Context, arg1 []string) error {
m.ctrl.T.Helper()
@ -338,6 +323,21 @@ func (mr *MockTaskManagerMockRecorder) GetAllNodes(arg0 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllNodes", reflect.TypeOf((*MockTaskManager)(nil).GetAllNodes), arg0)
}
// GetAllSubtasksByStepAndState mocks base method.
func (m *MockTaskManager) GetAllSubtasksByStepAndState(arg0 context.Context, arg1 int64, arg2 proto.Step, arg3 proto.SubtaskState) ([]*proto.Subtask, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAllSubtasksByStepAndState", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]*proto.Subtask)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAllSubtasksByStepAndState indicates an expected call of GetAllSubtasksByStepAndState.
func (mr *MockTaskManagerMockRecorder) GetAllSubtasksByStepAndState(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllSubtasksByStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetAllSubtasksByStepAndState), arg0, arg1, arg2, arg3)
}
// GetManagedNodes mocks base method.
func (m *MockTaskManager) GetManagedNodes(arg0 context.Context) ([]proto.ManagedNode, error) {
m.ctrl.T.Helper()
@ -368,34 +368,19 @@ func (mr *MockTaskManagerMockRecorder) GetSubtaskCntGroupByStates(arg0, arg1, ar
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskCntGroupByStates", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskCntGroupByStates), arg0, arg1, arg2)
}
// GetSubtasksByExecIdsAndStepAndState mocks base method.
func (m *MockTaskManager) GetSubtasksByExecIdsAndStepAndState(arg0 context.Context, arg1 []string, arg2 int64, arg3 proto.Step, arg4 proto.SubtaskState) ([]*proto.Subtask, error) {
// GetSubtaskErrors mocks base method.
func (m *MockTaskManager) GetSubtaskErrors(arg0 context.Context, arg1 int64) ([]error, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSubtasksByExecIdsAndStepAndState", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].([]*proto.Subtask)
ret := m.ctrl.Call(m, "GetSubtaskErrors", arg0, arg1)
ret0, _ := ret[0].([]error)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSubtasksByExecIdsAndStepAndState indicates an expected call of GetSubtasksByExecIdsAndStepAndState.
func (mr *MockTaskManagerMockRecorder) GetSubtasksByExecIdsAndStepAndState(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call {
// GetSubtaskErrors indicates an expected call of GetSubtaskErrors.
func (mr *MockTaskManagerMockRecorder) GetSubtaskErrors(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByExecIdsAndStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByExecIdsAndStepAndState), arg0, arg1, arg2, arg3, arg4)
}
// GetSubtasksByStepAndState mocks base method.
func (m *MockTaskManager) GetSubtasksByStepAndState(arg0 context.Context, arg1 int64, arg2 proto.Step, arg3 proto.TaskState) ([]*proto.Subtask, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetSubtasksByStepAndState", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].([]*proto.Subtask)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSubtasksByStepAndState indicates an expected call of GetSubtasksByStepAndState.
func (mr *MockTaskManagerMockRecorder) GetSubtasksByStepAndState(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByStepAndState), arg0, arg1, arg2, arg3)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskErrors", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskErrors), arg0, arg1)
}
// GetTaskByID mocks base method.
@ -428,21 +413,6 @@ func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskID(arg0, arg1 any
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskID), arg0, arg1)
}
// GetTaskExecutorIDsByTaskIDAndStep mocks base method.
func (m *MockTaskManager) GetTaskExecutorIDsByTaskIDAndStep(arg0 context.Context, arg1 int64, arg2 proto.Step) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetTaskExecutorIDsByTaskIDAndStep", arg0, arg1, arg2)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetTaskExecutorIDsByTaskIDAndStep indicates an expected call of GetTaskExecutorIDsByTaskIDAndStep.
func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskIDAndStep(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskIDAndStep", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskIDAndStep), arg0, arg1, arg2)
}
// GetTasksInStates mocks base method.
func (m *MockTaskManager) GetTasksInStates(arg0 context.Context, arg1 ...any) ([]*proto.Task, error) {
m.ctrl.T.Helper()

View File

@ -102,24 +102,24 @@ func (mr *MockTaskTableMockRecorder) GetFirstSubtaskInStates(arg0, arg1, arg2, a
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFirstSubtaskInStates", reflect.TypeOf((*MockTaskTable)(nil).GetFirstSubtaskInStates), varargs...)
}
// GetSubtasksByStepAndStates mocks base method.
func (m *MockTaskTable) GetSubtasksByStepAndStates(arg0 context.Context, arg1 string, arg2 int64, arg3 proto.Step, arg4 ...proto.SubtaskState) ([]*proto.Subtask, error) {
// GetSubtasksByExecIDAndStepAndStates mocks base method.
func (m *MockTaskTable) GetSubtasksByExecIDAndStepAndStates(arg0 context.Context, arg1 string, arg2 int64, arg3 proto.Step, arg4 ...proto.SubtaskState) ([]*proto.Subtask, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1, arg2, arg3}
for _, a := range arg4 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "GetSubtasksByStepAndStates", varargs...)
ret := m.ctrl.Call(m, "GetSubtasksByExecIDAndStepAndStates", varargs...)
ret0, _ := ret[0].([]*proto.Subtask)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetSubtasksByStepAndStates indicates an expected call of GetSubtasksByStepAndStates.
func (mr *MockTaskTableMockRecorder) GetSubtasksByStepAndStates(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call {
// GetSubtasksByExecIDAndStepAndStates indicates an expected call of GetSubtasksByExecIDAndStepAndStates.
func (mr *MockTaskTableMockRecorder) GetSubtasksByExecIDAndStepAndStates(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByStepAndStates", reflect.TypeOf((*MockTaskTable)(nil).GetSubtasksByStepAndStates), varargs...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByExecIDAndStepAndStates", reflect.TypeOf((*MockTaskTable)(nil).GetSubtasksByExecIDAndStepAndStates), varargs...)
}
// GetTaskByID mocks base method.

View File

@ -73,17 +73,18 @@ type TaskManager interface {
// GetSubtaskCntGroupByStates returns the count of subtasks of some step group by state.
GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error)
ResumeSubtasks(ctx context.Context, taskID int64) error
CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error)
GetSubtaskErrors(ctx context.Context, taskID int64) ([]error, error)
UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error
// GetManagedNodes returns the nodes managed by dist framework and can be used
// to execute tasks. If there are any nodes with background role, we use them,
// else we use nodes without role.
// returned nodes are sorted by node id(host:port).
GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error)
// GetTaskExecutorIDsByTaskID gets the task executor IDs of the given task ID.
GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error)
GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error)
GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error)
GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error)
// GetAllSubtasksByStepAndState gets all subtasks by given states for one step.
GetAllSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error)
WithNewSession(fn func(se sessionctx.Context) error) error
WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error

View File

@ -339,7 +339,7 @@ func (s *BaseScheduler) onRunning() error {
return err
}
if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 {
subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, task.ID)
subTaskErrs, err := s.taskMgr.GetSubtaskErrors(s.ctx, task.ID)
if err != nil {
logutil.Logger(s.logCtx).Warn("collect subtask error failed", zap.Error(err))
return err
@ -586,7 +586,7 @@ func (s *BaseScheduler) GetAllTaskExecutorIDs(ctx context.Context, task *proto.T
// GetPreviousSubtaskMetas get subtask metas from specific step.
func (s *BaseScheduler) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) {
previousSubtasks, err := s.taskMgr.GetSubtasksByStepAndState(s.ctx, taskID, step, proto.TaskStateSucceed)
previousSubtasks, err := s.taskMgr.GetAllSubtasksByStepAndState(s.ctx, taskID, step, proto.SubtaskStateSucceed)
if err != nil {
logutil.Logger(s.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", int64(step)))
return nil, err

View File

@ -22,6 +22,7 @@ import (
"github.com/ngaut/pools"
"github.com/pingcap/tidb/pkg/disttask/framework/mock"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/disttask/framework/testutil"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/util"
@ -75,7 +76,7 @@ func TestCleanUpRoutine(t *testing.T) {
}
sch.DoCleanUpRoutine()
require.Eventually(t, func() bool {
tasks, err := mgr.GetTasksFromHistoryInStates(ctx, proto.TaskStateSucceed)
tasks, err := testutil.GetTasksFromHistoryInStates(ctx, mgr, proto.TaskStateSucceed)
require.NoError(t, err)
return len(tasks) != 0
}, 5*time.Second*10, time.Millisecond*300)

View File

@ -341,7 +341,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel,
if len(tasks) == taskCnt {
break
}
historyTasks, err := mgr.GetTasksFromHistoryInStates(ctx, expectedState)
historyTasks, err := testutil.GetTasksFromHistoryInStates(ctx, mgr, expectedState)
require.NoError(t, err)
if len(tasks)+len(historyTasks) == taskCnt {
break

View File

@ -3,6 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
go_library(
name = "storage",
srcs = [
"converter.go",
"history.go",
"nodes.go",
"subtask_state.go",
"task_state.go",
"task_table.go",
],

View File

@ -0,0 +1,115 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"strconv"
"time"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)
func row2TaskBasic(r chunk.Row) *proto.Task {
task := &proto.Task{
ID: r.GetInt64(0),
Key: r.GetString(1),
Type: proto.TaskType(r.GetString(2)),
State: proto.TaskState(r.GetString(3)),
Step: proto.Step(r.GetInt64(4)),
Priority: int(r.GetInt64(5)),
Concurrency: int(r.GetInt64(6)),
}
task.CreateTime, _ = r.GetTime(7).GoTime(time.Local)
return task
}
// Row2Task converts a row to a task.
func Row2Task(r chunk.Row) *proto.Task {
task := row2TaskBasic(r)
var startTime, updateTime time.Time
if !r.IsNull(8) {
startTime, _ = r.GetTime(8).GoTime(time.Local)
}
if !r.IsNull(9) {
updateTime, _ = r.GetTime(9).GoTime(time.Local)
}
task.StartTime = startTime
task.StateUpdateTime = updateTime
task.Meta = r.GetBytes(10)
task.SchedulerID = r.GetString(11)
if !r.IsNull(12) {
errBytes := r.GetBytes(12)
stdErr := errors.Normalize("")
err := stdErr.UnmarshalJSON(errBytes)
if err != nil {
logutil.BgLogger().Error("unmarshal task error", zap.Error(err))
task.Error = errors.New(string(errBytes))
} else {
task.Error = stdErr
}
}
return task
}
// row2BasicSubTask converts a row to a subtask with basic info
func row2BasicSubTask(r chunk.Row) *proto.Subtask {
taskIDStr := r.GetString(2)
tid, err := strconv.Atoi(taskIDStr)
if err != nil {
logutil.BgLogger().Warn("unexpected subtask id", zap.String("subtask-id", taskIDStr))
}
createTime, _ := r.GetTime(7).GoTime(time.Local)
var ordinal int
if !r.IsNull(8) {
ordinal = int(r.GetInt64(8))
}
subtask := &proto.Subtask{
ID: r.GetInt64(0),
Step: proto.Step(r.GetInt64(1)),
TaskID: int64(tid),
Type: proto.Int2Type(int(r.GetInt64(3))),
ExecID: r.GetString(4),
State: proto.SubtaskState(r.GetString(5)),
Concurrency: int(r.GetInt64(6)),
CreateTime: createTime,
Ordinal: ordinal,
}
return subtask
}
// Row2SubTask converts a row to a subtask.
func Row2SubTask(r chunk.Row) *proto.Subtask {
subtask := row2BasicSubTask(r)
// subtask defines start/update time as bigint, to ensure backward compatible,
// we keep it that way, and we convert it here.
var startTime, updateTime time.Time
if !r.IsNull(9) {
ts := r.GetInt64(9)
startTime = time.Unix(ts, 0)
}
if !r.IsNull(10) {
ts := r.GetInt64(10)
updateTime = time.Unix(ts, 0)
}
subtask.StartTime = startTime
subtask.UpdateTime = updateTime
subtask.Meta = r.GetBytes(11)
subtask.Summary = r.GetJSON(12).String()
return subtask
}

View File

@ -0,0 +1,94 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"fmt"
"strings"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util/sqlexec"
)
// TransferSubtasks2HistoryWithSession transfer the selected subtasks into tidb_background_subtask_history table by taskID.
func (*TaskManager) TransferSubtasks2HistoryWithSession(ctx context.Context, se sessionctx.Context, taskID int64) error {
_, err := sqlexec.ExecSQL(ctx, se, `insert into mysql.tidb_background_subtask_history select * from mysql.tidb_background_subtask where task_key = %?`, taskID)
if err != nil {
return err
}
// delete taskID subtask
_, err = sqlexec.ExecSQL(ctx, se, "delete from mysql.tidb_background_subtask where task_key = %?", taskID)
return err
}
// TransferTasks2History transfer the selected tasks into tidb_global_task_history table by taskIDs.
func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*proto.Task) error {
if len(tasks) == 0 {
return nil
}
taskIDStrs := make([]string, 0, len(tasks))
for _, task := range tasks {
taskIDStrs = append(taskIDStrs, fmt.Sprintf("%d", task.ID))
}
return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
// sensitive data in meta might be redacted, need update first.
for _, t := range tasks {
_, err := sqlexec.ExecSQL(ctx, se, `
update mysql.tidb_global_task
set meta= %?, state_update_time = CURRENT_TIMESTAMP()
where id = %?`, t.Meta, t.ID)
if err != nil {
return err
}
}
_, err := sqlexec.ExecSQL(ctx, se, `
insert into mysql.tidb_global_task_history
select * from mysql.tidb_global_task
where id in(`+strings.Join(taskIDStrs, `, `)+`)`)
if err != nil {
return err
}
_, err = sqlexec.ExecSQL(ctx, se, `
delete from mysql.tidb_global_task
where id in(`+strings.Join(taskIDStrs, `, `)+`)`)
for _, t := range tasks {
err = mgr.TransferSubtasks2HistoryWithSession(ctx, se, t.ID)
if err != nil {
return err
}
}
return err
})
}
// GCSubtasks deletes the history subtask which is older than the given days.
func (mgr *TaskManager) GCSubtasks(ctx context.Context) error {
subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60
failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) {
if val, ok := val.(int); ok {
subtaskHistoryKeepSeconds = val
}
})
_, err := mgr.ExecuteSQLWithNewSession(
ctx,
fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds),
)
return err
}

View File

@ -0,0 +1,203 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"fmt"
"strings"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util/cpu"
"github.com/pingcap/tidb/pkg/util/sqlescape"
"github.com/pingcap/tidb/pkg/util/sqlexec"
)
// InitMeta insert the manager information into dist_framework_meta.
func (mgr *TaskManager) InitMeta(ctx context.Context, tidbID string, role string) error {
return mgr.WithNewSession(func(se sessionctx.Context) error {
return mgr.InitMetaSession(ctx, se, tidbID, role)
})
}
// InitMetaSession insert the manager information into dist_framework_meta.
// if the record exists, update the cpu_count and role.
func (*TaskManager) InitMetaSession(ctx context.Context, se sessionctx.Context, execID string, role string) error {
cpuCount := cpu.GetCPUCount()
_, err := sqlexec.ExecSQL(ctx, se, `
insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id)
values (%?, %?, %?, -1)
on duplicate key
update cpu_count = %?, role = %?`,
execID, role, cpuCount, cpuCount, role)
return err
}
// RecoverMeta insert the manager information into dist_framework_meta.
// if the record exists, update the cpu_count.
// Don't update role for we only update it in `set global tidb_service_scope`.
// if not there might has a data race.
func (mgr *TaskManager) RecoverMeta(ctx context.Context, execID string, role string) error {
cpuCount := cpu.GetCPUCount()
_, err := mgr.ExecuteSQLWithNewSession(ctx, `
insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id)
values (%?, %?, %?, -1)
on duplicate key
update cpu_count = %?`,
execID, role, cpuCount, cpuCount)
return err
}
// DeleteDeadNodes deletes the dead nodes from mysql.dist_framework_meta.
func (mgr *TaskManager) DeleteDeadNodes(ctx context.Context, nodes []string) error {
if len(nodes) == 0 {
return nil
}
return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
deleteSQL := new(strings.Builder)
if err := sqlescape.FormatSQL(deleteSQL, "delete from mysql.dist_framework_meta where host in("); err != nil {
return err
}
deleteElems := make([]string, 0, len(nodes))
for _, node := range nodes {
deleteElems = append(deleteElems, fmt.Sprintf(`"%s"`, node))
}
deleteSQL.WriteString(strings.Join(deleteElems, ", "))
deleteSQL.WriteString(")")
_, err := sqlexec.ExecSQL(ctx, se, deleteSQL.String())
return err
})
}
// GetManagedNodes implements scheduler.TaskManager interface.
func (mgr *TaskManager) GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) {
var nodes []proto.ManagedNode
err := mgr.WithNewSession(func(se sessionctx.Context) error {
var err2 error
nodes, err2 = mgr.getManagedNodesWithSession(ctx, se)
return err2
})
return nodes, err
}
func (mgr *TaskManager) getManagedNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) {
nodes, err := mgr.getAllNodesWithSession(ctx, se)
if err != nil {
return nil, err
}
nodeMap := make(map[string][]proto.ManagedNode, 2)
for _, node := range nodes {
nodeMap[node.Role] = append(nodeMap[node.Role], node)
}
if len(nodeMap["background"]) == 0 {
return nodeMap[""], nil
}
return nodeMap["background"], nil
}
// GetAllNodes gets nodes in dist_framework_meta.
func (mgr *TaskManager) GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) {
var nodes []proto.ManagedNode
err := mgr.WithNewSession(func(se sessionctx.Context) error {
var err2 error
nodes, err2 = mgr.getAllNodesWithSession(ctx, se)
return err2
})
return nodes, err
}
func (*TaskManager) getAllNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) {
rs, err := sqlexec.ExecSQL(ctx, se, `
select host, role, cpu_count
from mysql.dist_framework_meta
order by host`)
if err != nil {
return nil, err
}
nodes := make([]proto.ManagedNode, 0, len(rs))
for _, r := range rs {
nodes = append(nodes, proto.ManagedNode{
ID: r.GetString(0),
Role: r.GetString(1),
CPUCount: int(r.GetInt64(2)),
})
}
return nodes, nil
}
// GetUsedSlotsOnNodes implements the scheduler.TaskManager interface.
func (mgr *TaskManager) GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) {
// concurrency of subtasks of some step is the same, we use max(concurrency)
// to make group by works.
rs, err := mgr.ExecuteSQLWithNewSession(ctx, `
select
exec_id, sum(concurrency)
from (
select exec_id, task_key, max(concurrency) concurrency
from mysql.tidb_background_subtask
where state in (%?, %?)
group by exec_id, task_key
) a
group by exec_id`,
proto.SubtaskStatePending, proto.SubtaskStateRunning,
)
if err != nil {
return nil, err
}
slots := make(map[string]int, len(rs))
for _, r := range rs {
val, _ := r.GetMyDecimal(1).ToInt()
slots[r.GetString(0)] = int(val)
}
return slots, nil
}
// GetCPUCountOfManagedNode gets the cpu count of managed node.
func (mgr *TaskManager) GetCPUCountOfManagedNode(ctx context.Context) (int, error) {
var cnt int
err := mgr.WithNewSession(func(se sessionctx.Context) error {
var err2 error
cnt, err2 = mgr.getCPUCountOfManagedNode(ctx, se)
return err2
})
return cnt, err
}
// getCPUCountOfManagedNode gets the cpu count of managed node.
// returns error when there's no managed node or no node has valid cpu count.
func (mgr *TaskManager) getCPUCountOfManagedNode(ctx context.Context, se sessionctx.Context) (int, error) {
nodes, err := mgr.getManagedNodesWithSession(ctx, se)
if err != nil {
return 0, err
}
if len(nodes) == 0 {
return 0, errors.New("no managed nodes")
}
var cpuCount int
for _, n := range nodes {
if n.CPUCount > 0 {
cpuCount = n.CPUCount
break
}
}
if cpuCount == 0 {
return 0, errors.New("no managed node have enough resource for dist task")
}
return cpuCount, nil
}

View File

@ -0,0 +1,147 @@
// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/pingcap/tidb/pkg/disttask/framework/proto"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util/sqlexec"
)
// StartSubtask updates the subtask state to running.
func (mgr *TaskManager) StartSubtask(ctx context.Context, subtaskID int64, execID string) error {
err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
vars := se.GetSessionVars()
_, err := sqlexec.ExecSQL(ctx,
se,
`update mysql.tidb_background_subtask
set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp()
where id = %? and exec_id = %?`,
proto.SubtaskStateRunning,
subtaskID,
execID)
if err != nil {
return err
}
if vars.StmtCtx.AffectedRows() == 0 {
return ErrSubtaskNotFound
}
return nil
})
return err
}
// FinishSubtask updates the subtask meta and mark state to succeed.
func (mgr *TaskManager) FinishSubtask(ctx context.Context, execID string, id int64, meta []byte) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set meta = %?, state = %?, state_update_time = unix_timestamp(), end_time = CURRENT_TIMESTAMP()
where id = %? and exec_id = %?`,
meta, proto.SubtaskStateSucceed, id, execID)
return err
}
// FailSubtask update the task's subtask state to failed and set the err.
func (mgr *TaskManager) FailSubtask(ctx context.Context, execID string, taskID int64, err error) error {
if err == nil {
return nil
}
_, err1 := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask
set state = %?,
error = %?,
start_time = unix_timestamp(),
state_update_time = unix_timestamp(),
end_time = CURRENT_TIMESTAMP()
where exec_id = %? and
task_key = %? and
state in (%?, %?)
limit 1;`,
proto.SubtaskStateFailed,
serializeErr(err),
execID,
taskID,
proto.SubtaskStatePending,
proto.SubtaskStateRunning)
return err1
}
// CancelSubtask update the task's subtasks' state to canceled.
func (mgr *TaskManager) CancelSubtask(ctx context.Context, execID string, taskID int64) error {
_, err1 := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask
set state = %?,
start_time = unix_timestamp(),
state_update_time = unix_timestamp(),
end_time = CURRENT_TIMESTAMP()
where exec_id = %? and
task_key = %? and
state in (%?, %?)
limit 1;`,
proto.SubtaskStateCanceled,
execID,
taskID,
proto.SubtaskStatePending,
proto.SubtaskStateRunning)
return err1
}
// PauseSubtasks update all running/pending subtasks to pasued state.
func (mgr *TaskManager) PauseSubtasks(ctx context.Context, execID string, taskID int64) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, execID)
return err
}
// ResumeSubtasks update all paused subtasks to pending state.
func (mgr *TaskManager) ResumeSubtasks(ctx context.Context, taskID int64) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask set state = "pending", error = null where task_key = %? and state = "paused"`, taskID)
return err
}
// RunningSubtasksBack2Pending implements the taskexecutor.TaskTable interface.
func (mgr *TaskManager) RunningSubtasksBack2Pending(ctx context.Context, subtasks []*proto.Subtask) error {
// skip the update process.
if len(subtasks) == 0 {
return nil
}
err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
for _, subtask := range subtasks {
_, err := sqlexec.ExecSQL(ctx, se, `
update mysql.tidb_background_subtask
set state = %?, state_update_time = CURRENT_TIMESTAMP()
where id = %? and exec_id = %? and state = %?`,
proto.SubtaskStatePending, subtask.ID, subtask.ExecID, proto.SubtaskStateRunning)
if err != nil {
return err
}
}
return nil
})
return err
}
// UpdateSubtaskStateAndError updates the subtask state.
func (mgr *TaskManager) UpdateSubtaskStateAndError(
ctx context.Context,
execID string,
id int64, state proto.SubtaskState, subTaskErr error) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set state = %?, error = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`,
state, serializeErr(subTaskErr), id, execID)
return err
}

View File

@ -63,7 +63,7 @@ func TestTaskTable(t *testing.T) {
require.NoError(t, err)
require.Equal(t, int64(1), id)
task, err := gm.GetOneTask(ctx)
task, err := testutil.GetOneTask(ctx, gm)
require.NoError(t, err)
require.Equal(t, int64(1), task.ID)
require.Equal(t, "key1", task.Key)
@ -117,12 +117,12 @@ func TestTaskTable(t *testing.T) {
id, err = gm.CreateTask(ctx, "key2", "test", 4, []byte("test"))
require.NoError(t, err)
cancelling, err := gm.IsTaskCancelling(ctx, id)
cancelling, err := testutil.IsTaskCancelling(ctx, gm, id)
require.NoError(t, err)
require.False(t, cancelling)
require.NoError(t, gm.CancelTask(ctx, id))
cancelling, err = gm.IsTaskCancelling(ctx, id)
cancelling, err = testutil.IsTaskCancelling(ctx, gm, id)
require.NoError(t, err)
require.True(t, cancelling)
@ -212,7 +212,7 @@ func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, s
checkTaskStateStep(t, task, proto.TaskStateRunning, step)
require.GreaterOrEqual(t, task.StartTime, startTime)
require.GreaterOrEqual(t, task.StateUpdateTime, startTime)
gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task.ID, task.Step, proto.TaskStatePending)
gotSubtasks, err := tm.GetAllSubtasksByStepAndState(ctx, task.ID, task.Step, proto.SubtaskStatePending)
require.NoError(t, err)
require.Len(t, gotSubtasks, len(subtasks))
sort.Slice(gotSubtasks, func(i, j int) bool {
@ -329,7 +329,7 @@ func TestSwitchTaskStepInBatch(t *testing.T) {
task2, err = tm.GetTaskByID(ctx, task2.ID)
require.NoError(t, err)
checkTaskStateStep(t, task2, proto.TaskStatePending, proto.StepInit)
gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task2.ID, proto.StepOne, proto.TaskStatePending)
gotSubtasks, err := tm.GetAllSubtasksByStepAndState(ctx, task2.ID, proto.StepOne, proto.SubtaskStatePending)
require.NoError(t, err)
require.Len(t, gotSubtasks, 1)
// run again, should success
@ -541,8 +541,7 @@ func TestSubTaskTable(t *testing.T) {
ts := time.Now()
time.Sleep(time.Second)
err = sm.StartSubtask(ctx, 1, "tidb1")
require.NoError(t, err)
require.NoError(t, sm.StartSubtask(ctx, 1, "tidb1"))
err = sm.StartSubtask(ctx, 1, "tidb2")
require.Error(t, storage.ErrSubtaskNotFound, err)
@ -577,8 +576,7 @@ func TestSubTaskTable(t *testing.T) {
require.NoError(t, err)
require.False(t, ok)
err = sm.DeleteSubtasksByTaskID(ctx, 1)
require.NoError(t, err)
require.NoError(t, testutil.DeleteSubtasksByTaskID(ctx, sm, 1))
ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending, proto.SubtaskStateRunning)
require.NoError(t, err)
@ -590,22 +588,20 @@ func TestSubTaskTable(t *testing.T) {
require.NoError(t, err)
require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRevertPending])
subtasks, err := sm.GetSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.TaskStateSucceed)
subtasks, err := sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.SubtaskStateSucceed)
require.NoError(t, err)
require.Len(t, subtasks, 0)
err = sm.FinishSubtask(ctx, "tidb1", 2, []byte{})
require.NoError(t, err)
require.NoError(t, sm.FinishSubtask(ctx, "tidb1", 2, []byte{}))
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.TaskStateSucceed)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.SubtaskStateSucceed)
require.NoError(t, err)
require.Len(t, subtasks, 1)
rowCount, err := sm.GetSubtaskRowCount(ctx, 2, proto.StepInit)
require.NoError(t, err)
require.Equal(t, int64(0), rowCount)
err = sm.UpdateSubtaskRowCount(ctx, 2, 100)
require.NoError(t, err)
require.NoError(t, sm.UpdateSubtaskRowCount(ctx, 2, 100))
rowCount, err = sm.GetSubtaskRowCount(ctx, 2, proto.StepInit)
require.NoError(t, err)
require.Equal(t, int64(100), rowCount)
@ -613,62 +609,42 @@ func TestSubTaskTable(t *testing.T) {
// test UpdateSubtasksExecIDs
// 1. update one subtask
testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false)
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
subtasks[0].ExecID = "tidb2"
require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks))
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, "tidb2", subtasks[0].ExecID)
// 2. update 2 subtasks
testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false)
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
subtasks[0].ExecID = "tidb3"
require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks))
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, "tidb3", subtasks[0].ExecID)
require.Equal(t, "tidb1", subtasks[1].ExecID)
// update fail
require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateRunning, nil))
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, "tidb3", subtasks[0].ExecID)
subtasks[0].ExecID = "tidb2"
// update success
require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks))
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, "tidb2", subtasks[0].ExecID)
// test GetSubtasksByExecIdsAndStepAndState
testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false)
testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false)
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 6, proto.StepInit, proto.TaskStatePending)
require.NoError(t, err)
require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateRunning, nil))
subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1"}, 6, proto.StepInit, proto.SubtaskStateRunning)
require.NoError(t, err)
require.Equal(t, 1, len(subtasks))
subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1"}, 6, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, 1, len(subtasks))
testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb2", []byte("test"), proto.TaskTypeExample, 11, false)
subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1", "tidb2"}, 6, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, 2, len(subtasks))
subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1"}, 6, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, 1, len(subtasks))
// test CollectSubTaskError
// test GetSubtaskErrors
testutil.CreateSubTask(t, sm, 7, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false)
subtasks, err = sm.GetSubtasksByStepAndState(ctx, 7, proto.StepInit, proto.TaskStatePending)
subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 7, proto.StepInit, proto.SubtaskStatePending)
require.NoError(t, err)
require.Equal(t, 1, len(subtasks))
require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateFailed, errors.New("test err")))
subtaskErrs, err := sm.CollectSubTaskError(ctx, 7)
subtaskErrs, err := sm.GetSubtaskErrors(ctx, 7)
require.NoError(t, err)
require.Equal(t, 1, len(subtaskErrs))
require.ErrorContains(t, subtaskErrs[0], "test err")
@ -681,7 +657,7 @@ func TestBothTaskAndSubTaskTable(t *testing.T) {
require.NoError(t, err)
require.Equal(t, int64(1), id)
task, err := sm.GetOneTask(ctx)
task, err := testutil.GetOneTask(ctx, sm)
require.NoError(t, err)
require.Equal(t, proto.TaskStatePending, task.State)
@ -769,7 +745,7 @@ func TestBothTaskAndSubTaskTable(t *testing.T) {
require.Equal(t, int64(2), cntByStates[proto.SubtaskStateRevertPending])
// test transactional
require.NoError(t, sm.DeleteSubtasksByTaskID(ctx, 1))
require.NoError(t, testutil.DeleteSubtasksByTaskID(ctx, sm, 1))
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)"))
defer func() {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr"))

View File

@ -16,11 +16,9 @@ package storage
import (
"context"
"fmt"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/docker/go-units"
"github.com/ngaut/pools"
@ -31,21 +29,19 @@ import (
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/cpu"
"github.com/pingcap/tidb/pkg/util/intest"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/sqlescape"
"github.com/pingcap/tidb/pkg/util/sqlexec"
"github.com/tikv/client-go/v2/util"
"go.uber.org/zap"
)
const (
defaultSubtaskKeepDays = 14
basicTaskColumns = `id, task_key, type, state, step, priority, concurrency, create_time`
// TaskColumns is the columns for task.
// TODO: dispatcher_id will update to scheduler_id later
taskColumns = basicTaskColumns + `, start_time, state_update_time, meta, dispatcher_id, error`
TaskColumns = basicTaskColumns + `, start_time, state_update_time, meta, dispatcher_id, error`
// InsertTaskColumns is the columns used in insert task.
InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time`
basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal`
@ -132,48 +128,6 @@ func SetTaskManager(is *TaskManager) {
taskManagerInstance.Store(is)
}
func row2TaskBasic(r chunk.Row) *proto.Task {
task := &proto.Task{
ID: r.GetInt64(0),
Key: r.GetString(1),
Type: proto.TaskType(r.GetString(2)),
State: proto.TaskState(r.GetString(3)),
Step: proto.Step(r.GetInt64(4)),
Priority: int(r.GetInt64(5)),
Concurrency: int(r.GetInt64(6)),
}
task.CreateTime, _ = r.GetTime(7).GoTime(time.Local)
return task
}
// row2Task converts a row to a task.
func row2Task(r chunk.Row) *proto.Task {
task := row2TaskBasic(r)
var startTime, updateTime time.Time
if !r.IsNull(8) {
startTime, _ = r.GetTime(8).GoTime(time.Local)
}
if !r.IsNull(9) {
updateTime, _ = r.GetTime(9).GoTime(time.Local)
}
task.StartTime = startTime
task.StateUpdateTime = updateTime
task.Meta = r.GetBytes(10)
task.SchedulerID = r.GetString(11)
if !r.IsNull(12) {
errBytes := r.GetBytes(12)
stdErr := errors.Normalize("")
err := stdErr.UnmarshalJSON(errBytes)
if err != nil {
logutil.BgLogger().Error("unmarshal task error", zap.Error(err))
task.Error = errors.New(string(errBytes))
} else {
task.Error = stdErr
}
}
return task
}
// WithNewSession executes the function with a new session.
func (mgr *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error {
se, err := mgr.sePool.Get()
@ -266,20 +220,6 @@ func (mgr *TaskManager) CreateTaskWithSession(ctx context.Context, se sessionctx
return taskID, nil
}
// GetOneTask get a task from task table, it's used by scheduler only.
func (mgr *TaskManager) GetOneTask(ctx context.Context) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending)
if err != nil {
return task, err
}
if len(rs) == 0 {
return nil, nil
}
return row2Task(rs[0]), nil
}
// GetTopUnfinishedTasks implements the scheduler.TaskManager interface.
func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) (task []*proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx,
@ -312,7 +252,7 @@ func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...interfac
}
rs, err := mgr.ExecuteSQLWithNewSession(ctx,
"select "+taskColumns+" from mysql.tidb_global_task "+
"select "+TaskColumns+" from mysql.tidb_global_task "+
"where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)"+
" order by priority asc, create_time asc, id asc", states...)
if err != nil {
@ -320,31 +260,14 @@ func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...interfac
}
for _, r := range rs {
task = append(task, row2Task(r))
}
return task, nil
}
// GetTasksFromHistoryInStates gets the tasks in history table in the states.
func (mgr *TaskManager) GetTasksFromHistoryInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) {
if len(states) == 0 {
return task, nil
}
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task_history where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...)
if err != nil {
return task, err
}
for _, r := range rs {
task = append(task, row2Task(r))
task = append(task, Row2Task(r))
}
return task, nil
}
// GetTaskByID gets the task by the task ID.
func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where id = %?", taskID)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where id = %?", taskID)
if err != nil {
return task, err
}
@ -352,13 +275,13 @@ func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *pr
return nil, ErrTaskNotFound
}
return row2Task(rs[0]), nil
return Row2Task(rs[0]), nil
}
// GetTaskByIDWithHistory gets the task by the task ID from both tidb_global_task and tidb_global_task_history.
func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where id = %? "+
"union select "+taskColumns+" from mysql.tidb_global_task_history where id = %?", taskID, taskID)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where id = %? "+
"union select "+TaskColumns+" from mysql.tidb_global_task_history where id = %?", taskID, taskID)
if err != nil {
return task, err
}
@ -366,12 +289,12 @@ func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64
return nil, ErrTaskNotFound
}
return row2Task(rs[0]), nil
return Row2Task(rs[0]), nil
}
// GetTaskByKey gets the task by the task key.
func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where task_key = %?", key)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where task_key = %?", key)
if err != nil {
return task, err
}
@ -379,13 +302,13 @@ func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *pro
return nil, ErrTaskNotFound
}
return row2Task(rs[0]), nil
return Row2Task(rs[0]), nil
}
// GetTaskByKeyWithHistory gets the task from history table by the task key.
func (mgr *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where task_key = %?"+
"union select "+taskColumns+" from mysql.tidb_global_task_history where task_key = %?", key, key)
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where task_key = %?"+
"union select "+TaskColumns+" from mysql.tidb_global_task_history where task_key = %?", key, key)
if err != nil {
return task, err
}
@ -393,87 +316,12 @@ func (mgr *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string)
return nil, ErrTaskNotFound
}
return row2Task(rs[0]), nil
return Row2Task(rs[0]), nil
}
// GetUsedSlotsOnNodes implements the scheduler.TaskManager interface.
func (mgr *TaskManager) GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) {
// concurrency of subtasks of some step is the same, we use max(concurrency)
// to make group by works.
rs, err := mgr.ExecuteSQLWithNewSession(ctx, `
select
exec_id, sum(concurrency)
from (
select exec_id, task_key, max(concurrency) concurrency
from mysql.tidb_background_subtask
where state in (%?, %?)
group by exec_id, task_key
) a
group by exec_id`,
proto.TaskStatePending, proto.TaskStateRunning,
)
if err != nil {
return nil, err
}
slots := make(map[string]int, len(rs))
for _, r := range rs {
val, _ := r.GetMyDecimal(1).ToInt()
slots[r.GetString(0)] = int(val)
}
return slots, nil
}
// row2BasicSubTask converts a row to a subtask with basic info
func row2BasicSubTask(r chunk.Row) *proto.Subtask {
taskIDStr := r.GetString(2)
tid, err := strconv.Atoi(taskIDStr)
if err != nil {
logutil.BgLogger().Warn("unexpected subtask id", zap.String("subtask-id", taskIDStr))
}
createTime, _ := r.GetTime(7).GoTime(time.Local)
var ordinal int
if !r.IsNull(8) {
ordinal = int(r.GetInt64(8))
}
subtask := &proto.Subtask{
ID: r.GetInt64(0),
Step: proto.Step(r.GetInt64(1)),
TaskID: int64(tid),
Type: proto.Int2Type(int(r.GetInt64(3))),
ExecID: r.GetString(4),
State: proto.SubtaskState(r.GetString(5)),
Concurrency: int(r.GetInt64(6)),
CreateTime: createTime,
Ordinal: ordinal,
}
return subtask
}
// Row2SubTask converts a row to a subtask.
func Row2SubTask(r chunk.Row) *proto.Subtask {
subtask := row2BasicSubTask(r)
// subtask defines start/update time as bigint, to ensure backward compatible,
// we keep it that way, and we convert it here.
var startTime, updateTime time.Time
if !r.IsNull(9) {
ts := r.GetInt64(9)
startTime = time.Unix(ts, 0)
}
if !r.IsNull(10) {
ts := r.GetInt64(10)
updateTime = time.Unix(ts, 0)
}
subtask.StartTime = startTime
subtask.UpdateTime = updateTime
subtask.Meta = r.GetBytes(11)
subtask.Summary = r.GetJSON(12).String()
return subtask
}
// GetSubtasksByStepAndStates gets all subtasks by given states.
func (mgr *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) {
args := []interface{}{tidbID, taskID, step}
// GetSubtasksByExecIDAndStepAndStates gets all subtasks by given states on one node.
func (mgr *TaskManager) GetSubtasksByExecIDAndStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) {
args := []interface{}{execID, taskID, step}
for _, state := range states {
args = append(args, state)
}
@ -491,26 +339,6 @@ func (mgr *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID s
return subtasks, nil
}
// GetSubtasksByExecIdsAndStepAndState gets all subtasks by given taskID, exec_id, step and state.
func (mgr *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, execIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) {
args := []interface{}{taskID, step, state}
for _, execID := range execIDs {
args = append(args, execID)
}
rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask
where task_key = %? and step = %? and state = %?
and exec_id in (`+strings.Repeat("%?,", len(execIDs)-1)+"%?)", args...)
if err != nil {
return nil, err
}
subtasks := make([]*proto.Subtask, len(rs))
for i, row := range rs {
subtasks[i] = Row2SubTask(row)
}
return subtasks, nil
}
// GetFirstSubtaskInStates gets the first subtask by given states.
func (mgr *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (*proto.Subtask, error) {
args := []interface{}{tidbID, taskID, step}
@ -530,51 +358,6 @@ func (mgr *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID stri
return Row2SubTask(rs[0]), nil
}
// FailSubtask update the task's subtask state to failed and set the err.
func (mgr *TaskManager) FailSubtask(ctx context.Context, execID string, taskID int64, err error) error {
if err == nil {
return nil
}
_, err1 := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask
set state = %?,
error = %?,
start_time = unix_timestamp(),
state_update_time = unix_timestamp(),
end_time = CURRENT_TIMESTAMP()
where exec_id = %? and
task_key = %? and
state in (%?, %?)
limit 1;`,
proto.SubtaskStateFailed,
serializeErr(err),
execID,
taskID,
proto.SubtaskStatePending,
proto.SubtaskStateRunning)
return err1
}
// CancelSubtask update the task's subtasks' state to canceled.
func (mgr *TaskManager) CancelSubtask(ctx context.Context, execID string, taskID int64) error {
_, err1 := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask
set state = %?,
start_time = unix_timestamp(),
state_update_time = unix_timestamp(),
end_time = CURRENT_TIMESTAMP()
where exec_id = %? and
task_key = %? and
state in (%?, %?)
limit 1;`,
proto.SubtaskStateCanceled,
execID,
taskID,
proto.SubtaskStatePending,
proto.SubtaskStateRunning)
return err1
}
// GetActiveSubtasks implements TaskManager.GetActiveSubtasks.
func (mgr *TaskManager) GetActiveSubtasks(ctx context.Context, taskID int64) ([]*proto.Subtask, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, `
@ -591,8 +374,8 @@ func (mgr *TaskManager) GetActiveSubtasks(ctx context.Context, taskID int64) ([]
return subtasks, nil
}
// GetSubtasksByStepAndState gets the subtask by step and state.
func (mgr *TaskManager) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) {
// GetAllSubtasksByStepAndState gets the subtask by step and state.
func (mgr *TaskManager) GetAllSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask
where task_key = %? and state = %? and step = %?`,
taskID, state, step)
@ -654,8 +437,8 @@ func (mgr *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID i
return res, nil
}
// CollectSubTaskError collects the subtask error.
func (mgr *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) {
// GetSubtaskErrors gets subtasks' errors.
func (mgr *TaskManager) GetSubtaskErrors(ctx context.Context, taskID int64) ([]error, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx,
`select error from mysql.tidb_background_subtask
where task_key = %? AND state in (%?, %?)`, taskID, proto.SubtaskStateFailed, proto.SubtaskStateCanceled)
@ -700,95 +483,6 @@ func (mgr *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string,
return len(rs) > 0, nil
}
// StartSubtask updates the subtask state to running.
func (mgr *TaskManager) StartSubtask(ctx context.Context, subtaskID int64, execID string) error {
err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
vars := se.GetSessionVars()
_, err := sqlexec.ExecSQL(ctx,
se,
`update mysql.tidb_background_subtask
set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp()
where id = %? and exec_id = %?`,
proto.TaskStateRunning,
subtaskID,
execID)
if err != nil {
return err
}
if vars.StmtCtx.AffectedRows() == 0 {
return ErrSubtaskNotFound
}
return nil
})
return err
}
// InitMeta insert the manager information into dist_framework_meta.
func (mgr *TaskManager) InitMeta(ctx context.Context, tidbID string, role string) error {
return mgr.WithNewSession(func(se sessionctx.Context) error {
return mgr.InitMetaSession(ctx, se, tidbID, role)
})
}
// InitMetaSession insert the manager information into dist_framework_meta.
// if the record exists, update the cpu_count and role.
func (*TaskManager) InitMetaSession(ctx context.Context, se sessionctx.Context, execID string, role string) error {
cpuCount := cpu.GetCPUCount()
_, err := sqlexec.ExecSQL(ctx, se, `
insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id)
values (%?, %?, %?, -1)
on duplicate key
update cpu_count = %?, role = %?`,
execID, role, cpuCount, cpuCount, role)
return err
}
// RecoverMeta insert the manager information into dist_framework_meta.
// if the record exists, update the cpu_count.
// Don't update role for we only update it in `set global tidb_service_scope`.
// if not there might has a data race.
func (mgr *TaskManager) RecoverMeta(ctx context.Context, execID string, role string) error {
cpuCount := cpu.GetCPUCount()
_, err := mgr.ExecuteSQLWithNewSession(ctx, `
insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id)
values (%?, %?, %?, -1)
on duplicate key
update cpu_count = %?`,
execID, role, cpuCount, cpuCount)
return err
}
// UpdateSubtaskStateAndError updates the subtask state.
func (mgr *TaskManager) UpdateSubtaskStateAndError(
ctx context.Context,
execID string,
id int64, state proto.SubtaskState, subTaskErr error) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set state = %?, error = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`,
state, serializeErr(subTaskErr), id, execID)
return err
}
// FinishSubtask updates the subtask meta and mark state to succeed.
func (mgr *TaskManager) FinishSubtask(ctx context.Context, execID string, id int64, meta []byte) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask
set meta = %?, state = %?, state_update_time = unix_timestamp(), end_time = CURRENT_TIMESTAMP()
where id = %? and exec_id = %?`,
meta, proto.TaskStateSucceed, id, execID)
return err
}
// DeleteSubtasksByTaskID deletes the subtask of the given task ID.
func (mgr *TaskManager) DeleteSubtasksByTaskID(ctx context.Context, taskID int64) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx, `delete from mysql.tidb_background_subtask
where task_key = %?`, taskID)
if err != nil {
return err
}
return nil
}
// GetTaskExecutorIDsByTaskID gets the task executor IDs of the given task ID.
func (mgr *TaskManager) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select distinct(exec_id) from mysql.tidb_background_subtask
@ -809,26 +503,6 @@ func (mgr *TaskManager) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID i
return instanceIDs, nil
}
// GetTaskExecutorIDsByTaskIDAndStep gets the task executor IDs of the given global task ID and step.
func (mgr *TaskManager) GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select distinct(exec_id) from mysql.tidb_background_subtask
where task_key = %? and step = %?`, taskID, step)
if err != nil {
return nil, err
}
if len(rs) == 0 {
return nil, nil
}
instanceIDs := make([]string, 0, len(rs))
for _, r := range rs {
id := r.GetString(0)
instanceIDs = append(instanceIDs, id)
}
return instanceIDs, nil
}
// UpdateSubtasksExecIDs update subtasks' execID.
func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error {
// skip the update process.
@ -851,64 +525,6 @@ func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*p
return err
}
// RunningSubtasksBack2Pending implements the taskexecutor.TaskTable interface.
func (mgr *TaskManager) RunningSubtasksBack2Pending(ctx context.Context, subtasks []*proto.Subtask) error {
// skip the update process.
if len(subtasks) == 0 {
return nil
}
err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
for _, subtask := range subtasks {
_, err := sqlexec.ExecSQL(ctx, se, `
update mysql.tidb_background_subtask
set state = %?, state_update_time = CURRENT_TIMESTAMP()
where id = %? and exec_id = %? and state = %?`,
proto.SubtaskStatePending, subtask.ID, subtask.ExecID, proto.SubtaskStateRunning)
if err != nil {
return err
}
}
return nil
})
return err
}
// DeleteDeadNodes deletes the dead nodes from mysql.dist_framework_meta.
func (mgr *TaskManager) DeleteDeadNodes(ctx context.Context, nodes []string) error {
if len(nodes) == 0 {
return nil
}
return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
deleteSQL := new(strings.Builder)
if err := sqlescape.FormatSQL(deleteSQL, "delete from mysql.dist_framework_meta where host in("); err != nil {
return err
}
deleteElems := make([]string, 0, len(nodes))
for _, node := range nodes {
deleteElems = append(deleteElems, fmt.Sprintf(`"%s"`, node))
}
deleteSQL.WriteString(strings.Join(deleteElems, ", "))
deleteSQL.WriteString(")")
_, err := sqlexec.ExecSQL(ctx, se, deleteSQL.String())
return err
})
}
// PauseSubtasks update all running/pending subtasks to pasued state.
func (mgr *TaskManager) PauseSubtasks(ctx context.Context, execID string, taskID int64) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, execID)
return err
}
// ResumeSubtasks update all paused subtasks to pending state.
func (mgr *TaskManager) ResumeSubtasks(ctx context.Context, taskID int64) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx,
`update mysql.tidb_background_subtask set state = "pending", error = null where task_key = %? and state = "paused"`, taskID)
return err
}
// SwitchTaskStep implements the dispatcher.TaskManager interface.
func (mgr *TaskManager) SwitchTaskStep(
ctx context.Context,
@ -983,7 +599,7 @@ func (*TaskManager) insertSubtasks(ctx context.Context, se sessionctx.Context, s
for _, subtask := range subtasks {
markerList = append(markerList, "(%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')")
args = append(args, subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta,
proto.TaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal)
proto.SubtaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal)
}
sb.WriteString(strings.Join(markerList, ","))
_, err := sqlexec.ExecSQL(ctx, se, sb.String(), args...)
@ -1135,19 +751,6 @@ func serializeErr(err error) []byte {
return errBytes
}
// IsTaskCancelling checks whether the task state is cancelling.
func (mgr *TaskManager) IsTaskCancelling(ctx context.Context, taskID int64) (bool, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?",
taskID, proto.TaskStateCancelling,
)
if err != nil {
return false, err
}
return len(rs) > 0, nil
}
// GetSubtasksWithHistory gets the subtasks from tidb_global_task and tidb_global_task_history.
func (mgr *TaskManager) GetSubtasksWithHistory(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error) {
var (
@ -1189,161 +792,3 @@ func (mgr *TaskManager) GetSubtasksWithHistory(ctx context.Context, taskID int64
}
return subtasks, nil
}
// TransferSubtasks2HistoryWithSession transfer the selected subtasks into tidb_background_subtask_history table by taskID.
func (*TaskManager) TransferSubtasks2HistoryWithSession(ctx context.Context, se sessionctx.Context, taskID int64) error {
_, err := sqlexec.ExecSQL(ctx, se, `insert into mysql.tidb_background_subtask_history select * from mysql.tidb_background_subtask where task_key = %?`, taskID)
if err != nil {
return err
}
// delete taskID subtask
_, err = sqlexec.ExecSQL(ctx, se, "delete from mysql.tidb_background_subtask where task_key = %?", taskID)
return err
}
// TransferTasks2History transfer the selected tasks into tidb_global_task_history table by taskIDs.
func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*proto.Task) error {
if len(tasks) == 0 {
return nil
}
taskIDStrs := make([]string, 0, len(tasks))
for _, task := range tasks {
taskIDStrs = append(taskIDStrs, fmt.Sprintf("%d", task.ID))
}
return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
// sensitive data in meta might be redacted, need update first.
for _, t := range tasks {
_, err := sqlexec.ExecSQL(ctx, se, `
update mysql.tidb_global_task
set meta= %?, state_update_time = CURRENT_TIMESTAMP()
where id = %?`, t.Meta, t.ID)
if err != nil {
return err
}
}
_, err := sqlexec.ExecSQL(ctx, se, `
insert into mysql.tidb_global_task_history
select * from mysql.tidb_global_task
where id in(`+strings.Join(taskIDStrs, `, `)+`)`)
if err != nil {
return err
}
_, err = sqlexec.ExecSQL(ctx, se, `
delete from mysql.tidb_global_task
where id in(`+strings.Join(taskIDStrs, `, `)+`)`)
for _, t := range tasks {
err = mgr.TransferSubtasks2HistoryWithSession(ctx, se, t.ID)
if err != nil {
return err
}
}
return err
})
}
// GCSubtasks deletes the history subtask which is older than the given days.
func (mgr *TaskManager) GCSubtasks(ctx context.Context) error {
subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60
failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) {
if val, ok := val.(int); ok {
subtaskHistoryKeepSeconds = val
}
})
_, err := mgr.ExecuteSQLWithNewSession(
ctx,
fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds),
)
return err
}
// GetManagedNodes implements scheduler.TaskManager interface.
func (mgr *TaskManager) GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) {
var nodes []proto.ManagedNode
err := mgr.WithNewSession(func(se sessionctx.Context) error {
var err2 error
nodes, err2 = mgr.getManagedNodesWithSession(ctx, se)
return err2
})
return nodes, err
}
func (mgr *TaskManager) getManagedNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) {
nodes, err := mgr.getAllNodesWithSession(ctx, se)
if err != nil {
return nil, err
}
nodeMap := make(map[string][]proto.ManagedNode, 2)
for _, node := range nodes {
nodeMap[node.Role] = append(nodeMap[node.Role], node)
}
if len(nodeMap["background"]) == 0 {
return nodeMap[""], nil
}
return nodeMap["background"], nil
}
// GetAllNodes gets nodes in dist_framework_meta.
func (mgr *TaskManager) GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) {
var nodes []proto.ManagedNode
err := mgr.WithNewSession(func(se sessionctx.Context) error {
var err2 error
nodes, err2 = mgr.getAllNodesWithSession(ctx, se)
return err2
})
return nodes, err
}
func (*TaskManager) getAllNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) {
rs, err := sqlexec.ExecSQL(ctx, se, `
select host, role, cpu_count
from mysql.dist_framework_meta
order by host`)
if err != nil {
return nil, err
}
nodes := make([]proto.ManagedNode, 0, len(rs))
for _, r := range rs {
nodes = append(nodes, proto.ManagedNode{
ID: r.GetString(0),
Role: r.GetString(1),
CPUCount: int(r.GetInt64(2)),
})
}
return nodes, nil
}
// GetCPUCountOfManagedNode gets the cpu count of managed node.
func (mgr *TaskManager) GetCPUCountOfManagedNode(ctx context.Context) (int, error) {
var cnt int
err := mgr.WithNewSession(func(se sessionctx.Context) error {
var err2 error
cnt, err2 = mgr.getCPUCountOfManagedNode(ctx, se)
return err2
})
return cnt, err
}
// getCPUCountOfManagedNode gets the cpu count of managed node.
// returns error when there's no managed node or no node has valid cpu count.
func (mgr *TaskManager) getCPUCountOfManagedNode(ctx context.Context, se sessionctx.Context) (int, error) {
nodes, err := mgr.getManagedNodesWithSession(ctx, se)
if err != nil {
return 0, err
}
if len(nodes) == 0 {
return 0, errors.New("no managed nodes")
}
var cpuCount int
for _, n := range nodes {
if n.CPUCount > 0 {
cpuCount = n.CPUCount
break
}
}
if cpuCount == 0 {
return 0, errors.New("no managed node have enough resource for dist task")
}
return cpuCount, nil
}

View File

@ -25,7 +25,8 @@ import (
type TaskTable interface {
GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error)
GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error)
GetSubtasksByStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error)
// GetSubtasksByExecIDAndStepAndStates gets all subtasks by given states and execID.
GetSubtasksByExecIDAndStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error)
GetFirstSubtaskInStates(ctx context.Context, instanceID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (*proto.Subtask, error)
// InitMeta insert the manager information into dist_framework_meta.
// Call it when starting task executor or in set variable operation.

View File

@ -111,7 +111,7 @@ func (s *BaseTaskExecutor) checkBalanceSubtask(ctx context.Context) {
}
task := s.task.Load()
subtasks, err := s.taskTable.GetSubtasksByStepAndStates(ctx, s.id, task.ID, task.Step,
subtasks, err := s.taskTable.GetSubtasksByExecIDAndStepAndStates(ctx, s.id, task.ID, task.Step,
proto.SubtaskStateRunning)
if err != nil {
logutil.Logger(s.logCtx).Error("get subtasks failed", zap.Error(err))
@ -231,7 +231,7 @@ func (s *BaseTaskExecutor) runStep(ctx context.Context, task *proto.Task) (resEr
}
}()
subtasks, err := s.taskTable.GetSubtasksByStepAndStates(
subtasks, err := s.taskTable.GetSubtasksByExecIDAndStepAndStates(
runCtx, s.id, task.ID, task.Step,
proto.SubtaskStatePending, proto.SubtaskStateRunning)
if err != nil {

View File

@ -56,7 +56,7 @@ func TestTaskExecutorRun(t *testing.T) {
mockExtension.EXPECT().IsRetryableError(gomock.Any()).AnyTimes()
// mock for checkBalanceSubtask
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id",
taskID, proto.StepOne, proto.SubtaskStateRunning).Return([]*proto.Subtask{{ID: 1}}, nil).AnyTimes()
task1 := &proto.Task{Step: proto.StepOne, Type: tp}
@ -85,7 +85,7 @@ func TestTaskExecutorRun(t *testing.T) {
// 3. run subtask failed
runSubtaskErr := errors.New("run subtask error")
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne,
@ -102,7 +102,7 @@ func TestTaskExecutorRun(t *testing.T) {
// 4. run subtask success
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne,
@ -122,7 +122,7 @@ func TestTaskExecutorRun(t *testing.T) {
// 5. run subtask one by one
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(
[]*proto.Subtask{
{ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"},
@ -157,7 +157,7 @@ func TestTaskExecutorRun(t *testing.T) {
// idempotent, so fail it.
subtaskID := int64(2)
theSubtask := &proto.Subtask{ID: subtaskID, Type: tp, Step: proto.StepOne, State: proto.SubtaskStateRunning, ExecID: "id"}
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{theSubtask}, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne,
@ -172,7 +172,7 @@ func TestTaskExecutorRun(t *testing.T) {
// run previous left subtask in running state again, but the subtask idempotent,
// run it again.
theSubtask = &proto.Subtask{ID: subtaskID, Type: tp, Step: proto.StepOne, State: proto.SubtaskStateRunning, ExecID: "id"}
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{theSubtask}, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
// first round of the run loop
@ -192,7 +192,7 @@ func TestTaskExecutorRun(t *testing.T) {
require.True(t, ctrl.Satisfied())
// 6. cancel
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
@ -208,7 +208,7 @@ func TestTaskExecutorRun(t *testing.T) {
require.True(t, ctrl.Satisfied())
// 7. RunSubtask return context.Canceled
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
@ -223,7 +223,7 @@ func TestTaskExecutorRun(t *testing.T) {
require.True(t, ctrl.Satisfied())
// 8. grpc cancel
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
@ -239,7 +239,7 @@ func TestTaskExecutorRun(t *testing.T) {
require.True(t, ctrl.Satisfied())
// 9. annotate grpc cancel
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
@ -261,7 +261,7 @@ func TestTaskExecutorRun(t *testing.T) {
// 10. subtask owned by other executor
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{
ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne,
@ -378,7 +378,7 @@ func TestTaskExecutor(t *testing.T) {
mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil).AnyTimes()
mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(false).AnyTimes()
// mock for checkBalanceSubtask
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id",
taskID, proto.StepOne, proto.SubtaskStateRunning).Return([]*proto.Subtask{{ID: 1}}, nil).AnyTimes()
task := &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency}
@ -391,7 +391,7 @@ func TestTaskExecutor(t *testing.T) {
subtasks := []*proto.Subtask{
{ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"},
}
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(subtasks, nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(subtasks[0], nil)
@ -416,7 +416,7 @@ func TestTaskExecutor(t *testing.T) {
// 3. run one subtask, then task moved to history(ErrTaskNotFound).
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(subtasks, nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(subtasks[0], nil)
@ -449,12 +449,12 @@ func TestRunStepCurrentSubtaskScheduledAway(t *testing.T) {
taskExecutor.Extension = mockExtension
// mock for checkBalanceSubtask
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1",
task.ID, proto.StepOne, proto.SubtaskStateRunning).Return([]*proto.Subtask{}, nil)
// mock for runStep
mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil)
mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(false)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", task.ID, proto.StepOne,
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(subtasks, nil)
mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "tidb1", task.ID, proto.StepOne,
unfinishedNormalSubtaskStates...).Return(subtasks[0], nil)
@ -492,9 +492,9 @@ func TestCheckBalanceSubtask(t *testing.T) {
checkBalanceSubtaskInterval = 100 * time.Millisecond
// subtask scheduled away
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1",
task.ID, task.Step, proto.SubtaskStateRunning).Return(nil, errors.New("error"))
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1",
task.ID, task.Step, proto.SubtaskStateRunning).Return([]*proto.Subtask{}, nil)
runCtx, cancelCause := context.WithCancelCause(ctx)
taskExecutor.registerCancelFunc(cancelCause)
@ -505,7 +505,7 @@ func TestCheckBalanceSubtask(t *testing.T) {
subtasks := []*proto.Subtask{{ID: 1, ExecID: "tidb1"}}
// in-idempotent running subtask
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1",
task.ID, task.Step, proto.SubtaskStateRunning).Return(subtasks, nil)
mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(gomock.Any(), "tidb1",
subtasks[0].ID, proto.SubtaskStateFailed, ErrNonIdempotentSubtask).Return(nil)
@ -517,12 +517,12 @@ func TestCheckBalanceSubtask(t *testing.T) {
require.Zero(t, taskExecutor.currSubtaskID.Load())
taskExecutor.currSubtaskID.Store(1)
subtasks = []*proto.Subtask{{ID: 1, ExecID: "tidb1"}, {ID: 2, ExecID: "tidb1"}}
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1",
task.ID, task.Step, proto.SubtaskStateRunning).Return(subtasks, nil)
mockExtension.EXPECT().IsIdempotent(gomock.Any()).Return(true)
mockSubtaskTable.EXPECT().RunningSubtasksBack2Pending(gomock.Any(), []*proto.Subtask{{ID: 2, ExecID: "tidb1"}}).Return(nil)
// used to break the loop
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1",
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1",
task.ID, task.Step, proto.SubtaskStateRunning).Return(nil, nil)
taskExecutor.checkBalanceSubtask(ctx)
require.True(t, ctrl.Satisfied())
@ -576,32 +576,32 @@ func TestExecutorErrHandling(t *testing.T) {
require.True(t, ctrl.Satisfied())
// 5. GetSubtasksByStepAndStates meet retryable error.
getSubtasksByStepAndStatesErr := errors.New("get subtasks err")
getSubtasksByExecIDAndStepAndStatesErr := errors.New("get subtasks err")
mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(
gomock.Any(),
taskExecutor.id,
gomock.Any(),
proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByStepAndStatesErr)
unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByExecIDAndStepAndStatesErr)
mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(true)
require.NoError(t, taskExecutor.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency}))
require.True(t, ctrl.Satisfied())
// 6. GetSubtasksByStepAndStates meet non retryable error.
// 6. GetSubtasksByExecIDAndStepAndStates meet non retryable error.
mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(
gomock.Any(),
taskExecutor.id,
gomock.Any(),
proto.StepOne,
unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByStepAndStatesErr)
unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByExecIDAndStepAndStatesErr)
mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(false)
mockSubtaskTable.EXPECT().FailSubtask(runCtx, taskExecutor.id, gomock.Any(), getSubtasksByStepAndStatesErr)
mockSubtaskTable.EXPECT().FailSubtask(runCtx, taskExecutor.id, gomock.Any(), getSubtasksByExecIDAndStepAndStatesErr)
require.NoError(t, taskExecutor.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency}))
require.True(t, ctrl.Satisfied())
@ -609,7 +609,7 @@ func TestExecutorErrHandling(t *testing.T) {
cleanupErr := errors.New("cleanup err")
mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(
gomock.Any(),
taskExecutor.id,
gomock.Any(),
@ -634,7 +634,7 @@ func TestExecutorErrHandling(t *testing.T) {
// 8. Cleanup meet non retryable error.
mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(
gomock.Any(),
taskExecutor.id,
gomock.Any(),
@ -687,7 +687,7 @@ func TestExecutorErrHandling(t *testing.T) {
// 13. subtask succeed.
mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil)
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(
mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(
gomock.Any(),
taskExecutor.id,
gomock.Any(),

View File

@ -17,6 +17,7 @@ package testutil
import (
"context"
"fmt"
"strings"
"testing"
"time"
@ -79,6 +80,20 @@ func getTaskManager(t *testing.T, pool *pools.ResourcePool) *storage.TaskManager
return manager
}
// GetOneTask get a task from task table
func GetOneTask(ctx context.Context, mgr *storage.TaskManager) (task *proto.Task, err error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+storage.TaskColumns+" from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending)
if err != nil {
return task, err
}
if len(rs) == 0 {
return nil, nil
}
return storage.Row2Task(rs[0]), nil
}
// GetSubtasksFromHistory gets subtasks from history table for test.
func GetSubtasksFromHistory(ctx context.Context, mgr *storage.TaskManager) (int, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx,
@ -191,6 +206,47 @@ func TransferSubTasks2History(ctx context.Context, mgr *storage.TaskManager, tas
})
}
// GetTasksFromHistoryInStates gets the tasks in history table in the states.
func GetTasksFromHistoryInStates(ctx context.Context, mgr *storage.TaskManager, states ...interface{}) (task []*proto.Task, err error) {
if len(states) == 0 {
return task, nil
}
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+storage.TaskColumns+" from mysql.tidb_global_task_history where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...)
if err != nil {
return task, err
}
for _, r := range rs {
task = append(task, storage.Row2Task(r))
}
return task, nil
}
// DeleteSubtasksByTaskID deletes the subtask of the given task ID.
func DeleteSubtasksByTaskID(ctx context.Context, mgr *storage.TaskManager, taskID int64) error {
_, err := mgr.ExecuteSQLWithNewSession(ctx, `delete from mysql.tidb_background_subtask
where task_key = %?`, taskID)
if err != nil {
return err
}
return nil
}
// IsTaskCancelling checks whether the task state is cancelling.
func IsTaskCancelling(ctx context.Context, mgr *storage.TaskManager, taskID int64) (bool, error) {
rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?",
taskID, proto.TaskStateCancelling,
)
if err != nil {
return false, err
}
return len(rs) > 0, nil
}
// PrintSubtaskInfo log the subtask info by taskKey for test.
func PrintSubtaskInfo(ctx context.Context, mgr *storage.TaskManager, taskID int64) {
rs, _ := mgr.ExecuteSQLWithNewSession(ctx,

View File

@ -769,7 +769,7 @@ func (s *mockGCSSuite) TestColumnsAndUserVars() {
taskManager := storage.NewTaskManager(pool)
ctx := context.Background()
ctx = util.WithInternalSourceType(ctx, "taskManager")
subtasks, err := taskManager.GetSubtasksByStepAndState(ctx, storage.TestLastTaskID.Load(), importinto.StepImport, proto.TaskStateSucceed)
subtasks, err := taskManager.GetAllSubtasksByStepAndState(ctx, storage.TestLastTaskID.Load(), importinto.StepImport, proto.SubtaskStateSucceed)
s.NoError(err)
s.Len(subtasks, 1)
serverInfo, err := infosync.GetServerInfo()