diff --git a/pkg/disttask/framework/framework_pause_and_resume_test.go b/pkg/disttask/framework/framework_pause_and_resume_test.go index 28a2ad92d3..2c5020014d 100644 --- a/pkg/disttask/framework/framework_pause_and_resume_test.go +++ b/pkg/disttask/framework/framework_pause_and_resume_test.go @@ -60,7 +60,7 @@ func TestFrameworkPauseAndResume(t *testing.T) { mgr, err := storage.GetTaskManager() require.NoError(t, err) - errs, err := mgr.CollectSubTaskError(ctx, 1) + errs, err := mgr.GetSubtaskErrors(ctx, 1) require.NoError(t, err) require.Empty(t, errs) @@ -77,7 +77,7 @@ func TestFrameworkPauseAndResume(t *testing.T) { CheckSubtasksState(ctx, t, 1, proto.SubtaskStateSucceed, 4) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/scheduler/syncAfterResume")) - errs, err = mgr.CollectSubTaskError(ctx, 1) + errs, err = mgr.GetSubtaskErrors(ctx, 1) require.NoError(t, err) require.Empty(t, errs) distContext.Close() diff --git a/pkg/disttask/framework/mock/scheduler_mock.go b/pkg/disttask/framework/mock/scheduler_mock.go index be08d82c7c..57002fbd44 100644 --- a/pkg/disttask/framework/mock/scheduler_mock.go +++ b/pkg/disttask/framework/mock/scheduler_mock.go @@ -251,21 +251,6 @@ func (mr *MockTaskManagerMockRecorder) CancelTask(arg0, arg1 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CancelTask", reflect.TypeOf((*MockTaskManager)(nil).CancelTask), arg0, arg1) } -// CollectSubTaskError mocks base method. -func (m *MockTaskManager) CollectSubTaskError(arg0 context.Context, arg1 int64) ([]error, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CollectSubTaskError", arg0, arg1) - ret0, _ := ret[0].([]error) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CollectSubTaskError indicates an expected call of CollectSubTaskError. -func (mr *MockTaskManagerMockRecorder) CollectSubTaskError(arg0, arg1 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CollectSubTaskError", reflect.TypeOf((*MockTaskManager)(nil).CollectSubTaskError), arg0, arg1) -} - // DeleteDeadNodes mocks base method. func (m *MockTaskManager) DeleteDeadNodes(arg0 context.Context, arg1 []string) error { m.ctrl.T.Helper() @@ -338,6 +323,21 @@ func (mr *MockTaskManagerMockRecorder) GetAllNodes(arg0 any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllNodes", reflect.TypeOf((*MockTaskManager)(nil).GetAllNodes), arg0) } +// GetAllSubtasksByStepAndState mocks base method. +func (m *MockTaskManager) GetAllSubtasksByStepAndState(arg0 context.Context, arg1 int64, arg2 proto.Step, arg3 proto.SubtaskState) ([]*proto.Subtask, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAllSubtasksByStepAndState", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].([]*proto.Subtask) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAllSubtasksByStepAndState indicates an expected call of GetAllSubtasksByStepAndState. +func (mr *MockTaskManagerMockRecorder) GetAllSubtasksByStepAndState(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllSubtasksByStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetAllSubtasksByStepAndState), arg0, arg1, arg2, arg3) +} + // GetManagedNodes mocks base method. func (m *MockTaskManager) GetManagedNodes(arg0 context.Context) ([]proto.ManagedNode, error) { m.ctrl.T.Helper() @@ -368,34 +368,19 @@ func (mr *MockTaskManagerMockRecorder) GetSubtaskCntGroupByStates(arg0, arg1, ar return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskCntGroupByStates", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskCntGroupByStates), arg0, arg1, arg2) } -// GetSubtasksByExecIdsAndStepAndState mocks base method. -func (m *MockTaskManager) GetSubtasksByExecIdsAndStepAndState(arg0 context.Context, arg1 []string, arg2 int64, arg3 proto.Step, arg4 proto.SubtaskState) ([]*proto.Subtask, error) { +// GetSubtaskErrors mocks base method. +func (m *MockTaskManager) GetSubtaskErrors(arg0 context.Context, arg1 int64) ([]error, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSubtasksByExecIdsAndStepAndState", arg0, arg1, arg2, arg3, arg4) - ret0, _ := ret[0].([]*proto.Subtask) + ret := m.ctrl.Call(m, "GetSubtaskErrors", arg0, arg1) + ret0, _ := ret[0].([]error) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSubtasksByExecIdsAndStepAndState indicates an expected call of GetSubtasksByExecIdsAndStepAndState. -func (mr *MockTaskManagerMockRecorder) GetSubtasksByExecIdsAndStepAndState(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call { +// GetSubtaskErrors indicates an expected call of GetSubtaskErrors. +func (mr *MockTaskManagerMockRecorder) GetSubtaskErrors(arg0, arg1 any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByExecIdsAndStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByExecIdsAndStepAndState), arg0, arg1, arg2, arg3, arg4) -} - -// GetSubtasksByStepAndState mocks base method. -func (m *MockTaskManager) GetSubtasksByStepAndState(arg0 context.Context, arg1 int64, arg2 proto.Step, arg3 proto.TaskState) ([]*proto.Subtask, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSubtasksByStepAndState", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].([]*proto.Subtask) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetSubtasksByStepAndState indicates an expected call of GetSubtasksByStepAndState. -func (mr *MockTaskManagerMockRecorder) GetSubtasksByStepAndState(arg0, arg1, arg2, arg3 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByStepAndState", reflect.TypeOf((*MockTaskManager)(nil).GetSubtasksByStepAndState), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtaskErrors", reflect.TypeOf((*MockTaskManager)(nil).GetSubtaskErrors), arg0, arg1) } // GetTaskByID mocks base method. @@ -428,21 +413,6 @@ func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskID(arg0, arg1 any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskID", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskID), arg0, arg1) } -// GetTaskExecutorIDsByTaskIDAndStep mocks base method. -func (m *MockTaskManager) GetTaskExecutorIDsByTaskIDAndStep(arg0 context.Context, arg1 int64, arg2 proto.Step) ([]string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetTaskExecutorIDsByTaskIDAndStep", arg0, arg1, arg2) - ret0, _ := ret[0].([]string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetTaskExecutorIDsByTaskIDAndStep indicates an expected call of GetTaskExecutorIDsByTaskIDAndStep. -func (mr *MockTaskManagerMockRecorder) GetTaskExecutorIDsByTaskIDAndStep(arg0, arg1, arg2 any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetTaskExecutorIDsByTaskIDAndStep", reflect.TypeOf((*MockTaskManager)(nil).GetTaskExecutorIDsByTaskIDAndStep), arg0, arg1, arg2) -} - // GetTasksInStates mocks base method. func (m *MockTaskManager) GetTasksInStates(arg0 context.Context, arg1 ...any) ([]*proto.Task, error) { m.ctrl.T.Helper() diff --git a/pkg/disttask/framework/mock/task_executor_mock.go b/pkg/disttask/framework/mock/task_executor_mock.go index c87393857f..725fe78d7b 100644 --- a/pkg/disttask/framework/mock/task_executor_mock.go +++ b/pkg/disttask/framework/mock/task_executor_mock.go @@ -102,24 +102,24 @@ func (mr *MockTaskTableMockRecorder) GetFirstSubtaskInStates(arg0, arg1, arg2, a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFirstSubtaskInStates", reflect.TypeOf((*MockTaskTable)(nil).GetFirstSubtaskInStates), varargs...) } -// GetSubtasksByStepAndStates mocks base method. -func (m *MockTaskTable) GetSubtasksByStepAndStates(arg0 context.Context, arg1 string, arg2 int64, arg3 proto.Step, arg4 ...proto.SubtaskState) ([]*proto.Subtask, error) { +// GetSubtasksByExecIDAndStepAndStates mocks base method. +func (m *MockTaskTable) GetSubtasksByExecIDAndStepAndStates(arg0 context.Context, arg1 string, arg2 int64, arg3 proto.Step, arg4 ...proto.SubtaskState) ([]*proto.Subtask, error) { m.ctrl.T.Helper() varargs := []any{arg0, arg1, arg2, arg3} for _, a := range arg4 { varargs = append(varargs, a) } - ret := m.ctrl.Call(m, "GetSubtasksByStepAndStates", varargs...) + ret := m.ctrl.Call(m, "GetSubtasksByExecIDAndStepAndStates", varargs...) ret0, _ := ret[0].([]*proto.Subtask) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSubtasksByStepAndStates indicates an expected call of GetSubtasksByStepAndStates. -func (mr *MockTaskTableMockRecorder) GetSubtasksByStepAndStates(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { +// GetSubtasksByExecIDAndStepAndStates indicates an expected call of GetSubtasksByExecIDAndStepAndStates. +func (mr *MockTaskTableMockRecorder) GetSubtasksByExecIDAndStepAndStates(arg0, arg1, arg2, arg3 any, arg4 ...any) *gomock.Call { mr.mock.ctrl.T.Helper() varargs := append([]any{arg0, arg1, arg2, arg3}, arg4...) - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByStepAndStates", reflect.TypeOf((*MockTaskTable)(nil).GetSubtasksByStepAndStates), varargs...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSubtasksByExecIDAndStepAndStates", reflect.TypeOf((*MockTaskTable)(nil).GetSubtasksByExecIDAndStepAndStates), varargs...) } // GetTaskByID mocks base method. diff --git a/pkg/disttask/framework/scheduler/interface.go b/pkg/disttask/framework/scheduler/interface.go index 0a3b1419d4..ec9b8f6d69 100644 --- a/pkg/disttask/framework/scheduler/interface.go +++ b/pkg/disttask/framework/scheduler/interface.go @@ -73,17 +73,18 @@ type TaskManager interface { // GetSubtaskCntGroupByStates returns the count of subtasks of some step group by state. GetSubtaskCntGroupByStates(ctx context.Context, taskID int64, step proto.Step) (map[proto.SubtaskState]int64, error) ResumeSubtasks(ctx context.Context, taskID int64) error - CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) + GetSubtaskErrors(ctx context.Context, taskID int64) ([]error, error) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error // GetManagedNodes returns the nodes managed by dist framework and can be used // to execute tasks. If there are any nodes with background role, we use them, // else we use nodes without role. // returned nodes are sorted by node id(host:port). GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) + // GetTaskExecutorIDsByTaskID gets the task executor IDs of the given task ID. GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) - GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) - GetSubtasksByExecIdsAndStepAndState(ctx context.Context, tidbIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) - GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error) + + // GetAllSubtasksByStepAndState gets all subtasks by given states for one step. + GetAllSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) WithNewSession(fn func(se sessionctx.Context) error) error WithNewTxn(ctx context.Context, fn func(se sessionctx.Context) error) error diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index c424f67bb4..761c053274 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -339,7 +339,7 @@ func (s *BaseScheduler) onRunning() error { return err } if cntByStates[proto.SubtaskStateFailed] > 0 || cntByStates[proto.SubtaskStateCanceled] > 0 { - subTaskErrs, err := s.taskMgr.CollectSubTaskError(s.ctx, task.ID) + subTaskErrs, err := s.taskMgr.GetSubtaskErrors(s.ctx, task.ID) if err != nil { logutil.Logger(s.logCtx).Warn("collect subtask error failed", zap.Error(err)) return err @@ -586,7 +586,7 @@ func (s *BaseScheduler) GetAllTaskExecutorIDs(ctx context.Context, task *proto.T // GetPreviousSubtaskMetas get subtask metas from specific step. func (s *BaseScheduler) GetPreviousSubtaskMetas(taskID int64, step proto.Step) ([][]byte, error) { - previousSubtasks, err := s.taskMgr.GetSubtasksByStepAndState(s.ctx, taskID, step, proto.TaskStateSucceed) + previousSubtasks, err := s.taskMgr.GetAllSubtasksByStepAndState(s.ctx, taskID, step, proto.SubtaskStateSucceed) if err != nil { logutil.Logger(s.logCtx).Warn("get previous succeed subtask failed", zap.Int64("step", int64(step))) return nil, err diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go index 1ebc2ec541..7988920d47 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager_test.go @@ -22,6 +22,7 @@ import ( "github.com/ngaut/pools" "github.com/pingcap/tidb/pkg/disttask/framework/mock" "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/disttask/framework/testutil" "github.com/pingcap/tidb/pkg/testkit" "github.com/stretchr/testify/require" "github.com/tikv/client-go/v2/util" @@ -75,7 +76,7 @@ func TestCleanUpRoutine(t *testing.T) { } sch.DoCleanUpRoutine() require.Eventually(t, func() bool { - tasks, err := mgr.GetTasksFromHistoryInStates(ctx, proto.TaskStateSucceed) + tasks, err := testutil.GetTasksFromHistoryInStates(ctx, mgr, proto.TaskStateSucceed) require.NoError(t, err) return len(tasks) != 0 }, 5*time.Second*10, time.Millisecond*300) diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index 39f86420e0..752d499832 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -341,7 +341,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, if len(tasks) == taskCnt { break } - historyTasks, err := mgr.GetTasksFromHistoryInStates(ctx, expectedState) + historyTasks, err := testutil.GetTasksFromHistoryInStates(ctx, mgr, expectedState) require.NoError(t, err) if len(tasks)+len(historyTasks) == taskCnt { break diff --git a/pkg/disttask/framework/storage/BUILD.bazel b/pkg/disttask/framework/storage/BUILD.bazel index 4467186808..d68d574ff4 100644 --- a/pkg/disttask/framework/storage/BUILD.bazel +++ b/pkg/disttask/framework/storage/BUILD.bazel @@ -3,6 +3,10 @@ load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") go_library( name = "storage", srcs = [ + "converter.go", + "history.go", + "nodes.go", + "subtask_state.go", "task_state.go", "task_table.go", ], diff --git a/pkg/disttask/framework/storage/converter.go b/pkg/disttask/framework/storage/converter.go new file mode 100644 index 0000000000..4f43f27401 --- /dev/null +++ b/pkg/disttask/framework/storage/converter.go @@ -0,0 +1,115 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "strconv" + "time" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/util/chunk" + "github.com/pingcap/tidb/pkg/util/logutil" + "go.uber.org/zap" +) + +func row2TaskBasic(r chunk.Row) *proto.Task { + task := &proto.Task{ + ID: r.GetInt64(0), + Key: r.GetString(1), + Type: proto.TaskType(r.GetString(2)), + State: proto.TaskState(r.GetString(3)), + Step: proto.Step(r.GetInt64(4)), + Priority: int(r.GetInt64(5)), + Concurrency: int(r.GetInt64(6)), + } + task.CreateTime, _ = r.GetTime(7).GoTime(time.Local) + return task +} + +// Row2Task converts a row to a task. +func Row2Task(r chunk.Row) *proto.Task { + task := row2TaskBasic(r) + var startTime, updateTime time.Time + if !r.IsNull(8) { + startTime, _ = r.GetTime(8).GoTime(time.Local) + } + if !r.IsNull(9) { + updateTime, _ = r.GetTime(9).GoTime(time.Local) + } + task.StartTime = startTime + task.StateUpdateTime = updateTime + task.Meta = r.GetBytes(10) + task.SchedulerID = r.GetString(11) + if !r.IsNull(12) { + errBytes := r.GetBytes(12) + stdErr := errors.Normalize("") + err := stdErr.UnmarshalJSON(errBytes) + if err != nil { + logutil.BgLogger().Error("unmarshal task error", zap.Error(err)) + task.Error = errors.New(string(errBytes)) + } else { + task.Error = stdErr + } + } + return task +} + +// row2BasicSubTask converts a row to a subtask with basic info +func row2BasicSubTask(r chunk.Row) *proto.Subtask { + taskIDStr := r.GetString(2) + tid, err := strconv.Atoi(taskIDStr) + if err != nil { + logutil.BgLogger().Warn("unexpected subtask id", zap.String("subtask-id", taskIDStr)) + } + createTime, _ := r.GetTime(7).GoTime(time.Local) + var ordinal int + if !r.IsNull(8) { + ordinal = int(r.GetInt64(8)) + } + subtask := &proto.Subtask{ + ID: r.GetInt64(0), + Step: proto.Step(r.GetInt64(1)), + TaskID: int64(tid), + Type: proto.Int2Type(int(r.GetInt64(3))), + ExecID: r.GetString(4), + State: proto.SubtaskState(r.GetString(5)), + Concurrency: int(r.GetInt64(6)), + CreateTime: createTime, + Ordinal: ordinal, + } + return subtask +} + +// Row2SubTask converts a row to a subtask. +func Row2SubTask(r chunk.Row) *proto.Subtask { + subtask := row2BasicSubTask(r) + // subtask defines start/update time as bigint, to ensure backward compatible, + // we keep it that way, and we convert it here. + var startTime, updateTime time.Time + if !r.IsNull(9) { + ts := r.GetInt64(9) + startTime = time.Unix(ts, 0) + } + if !r.IsNull(10) { + ts := r.GetInt64(10) + updateTime = time.Unix(ts, 0) + } + subtask.StartTime = startTime + subtask.UpdateTime = updateTime + subtask.Meta = r.GetBytes(11) + subtask.Summary = r.GetJSON(12).String() + return subtask +} diff --git a/pkg/disttask/framework/storage/history.go b/pkg/disttask/framework/storage/history.go new file mode 100644 index 0000000000..75f4e8d066 --- /dev/null +++ b/pkg/disttask/framework/storage/history.go @@ -0,0 +1,94 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "fmt" + "strings" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/sqlexec" +) + +// TransferSubtasks2HistoryWithSession transfer the selected subtasks into tidb_background_subtask_history table by taskID. +func (*TaskManager) TransferSubtasks2HistoryWithSession(ctx context.Context, se sessionctx.Context, taskID int64) error { + _, err := sqlexec.ExecSQL(ctx, se, `insert into mysql.tidb_background_subtask_history select * from mysql.tidb_background_subtask where task_key = %?`, taskID) + if err != nil { + return err + } + // delete taskID subtask + _, err = sqlexec.ExecSQL(ctx, se, "delete from mysql.tidb_background_subtask where task_key = %?", taskID) + return err +} + +// TransferTasks2History transfer the selected tasks into tidb_global_task_history table by taskIDs. +func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*proto.Task) error { + if len(tasks) == 0 { + return nil + } + taskIDStrs := make([]string, 0, len(tasks)) + for _, task := range tasks { + taskIDStrs = append(taskIDStrs, fmt.Sprintf("%d", task.ID)) + } + return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + // sensitive data in meta might be redacted, need update first. + for _, t := range tasks { + _, err := sqlexec.ExecSQL(ctx, se, ` + update mysql.tidb_global_task + set meta= %?, state_update_time = CURRENT_TIMESTAMP() + where id = %?`, t.Meta, t.ID) + if err != nil { + return err + } + } + _, err := sqlexec.ExecSQL(ctx, se, ` + insert into mysql.tidb_global_task_history + select * from mysql.tidb_global_task + where id in(`+strings.Join(taskIDStrs, `, `)+`)`) + if err != nil { + return err + } + + _, err = sqlexec.ExecSQL(ctx, se, ` + delete from mysql.tidb_global_task + where id in(`+strings.Join(taskIDStrs, `, `)+`)`) + + for _, t := range tasks { + err = mgr.TransferSubtasks2HistoryWithSession(ctx, se, t.ID) + if err != nil { + return err + } + } + return err + }) +} + +// GCSubtasks deletes the history subtask which is older than the given days. +func (mgr *TaskManager) GCSubtasks(ctx context.Context) error { + subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60 + failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) { + if val, ok := val.(int); ok { + subtaskHistoryKeepSeconds = val + } + }) + _, err := mgr.ExecuteSQLWithNewSession( + ctx, + fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds), + ) + return err +} diff --git a/pkg/disttask/framework/storage/nodes.go b/pkg/disttask/framework/storage/nodes.go new file mode 100644 index 0000000000..a4b4eda415 --- /dev/null +++ b/pkg/disttask/framework/storage/nodes.go @@ -0,0 +1,203 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + "fmt" + "strings" + + "github.com/pingcap/errors" + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/cpu" + "github.com/pingcap/tidb/pkg/util/sqlescape" + "github.com/pingcap/tidb/pkg/util/sqlexec" +) + +// InitMeta insert the manager information into dist_framework_meta. +func (mgr *TaskManager) InitMeta(ctx context.Context, tidbID string, role string) error { + return mgr.WithNewSession(func(se sessionctx.Context) error { + return mgr.InitMetaSession(ctx, se, tidbID, role) + }) +} + +// InitMetaSession insert the manager information into dist_framework_meta. +// if the record exists, update the cpu_count and role. +func (*TaskManager) InitMetaSession(ctx context.Context, se sessionctx.Context, execID string, role string) error { + cpuCount := cpu.GetCPUCount() + _, err := sqlexec.ExecSQL(ctx, se, ` + insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id) + values (%?, %?, %?, -1) + on duplicate key + update cpu_count = %?, role = %?`, + execID, role, cpuCount, cpuCount, role) + return err +} + +// RecoverMeta insert the manager information into dist_framework_meta. +// if the record exists, update the cpu_count. +// Don't update role for we only update it in `set global tidb_service_scope`. +// if not there might has a data race. +func (mgr *TaskManager) RecoverMeta(ctx context.Context, execID string, role string) error { + cpuCount := cpu.GetCPUCount() + _, err := mgr.ExecuteSQLWithNewSession(ctx, ` + insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id) + values (%?, %?, %?, -1) + on duplicate key + update cpu_count = %?`, + execID, role, cpuCount, cpuCount) + return err +} + +// DeleteDeadNodes deletes the dead nodes from mysql.dist_framework_meta. +func (mgr *TaskManager) DeleteDeadNodes(ctx context.Context, nodes []string) error { + if len(nodes) == 0 { + return nil + } + return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + deleteSQL := new(strings.Builder) + if err := sqlescape.FormatSQL(deleteSQL, "delete from mysql.dist_framework_meta where host in("); err != nil { + return err + } + deleteElems := make([]string, 0, len(nodes)) + for _, node := range nodes { + deleteElems = append(deleteElems, fmt.Sprintf(`"%s"`, node)) + } + + deleteSQL.WriteString(strings.Join(deleteElems, ", ")) + deleteSQL.WriteString(")") + _, err := sqlexec.ExecSQL(ctx, se, deleteSQL.String()) + return err + }) +} + +// GetManagedNodes implements scheduler.TaskManager interface. +func (mgr *TaskManager) GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) { + var nodes []proto.ManagedNode + err := mgr.WithNewSession(func(se sessionctx.Context) error { + var err2 error + nodes, err2 = mgr.getManagedNodesWithSession(ctx, se) + return err2 + }) + return nodes, err +} + +func (mgr *TaskManager) getManagedNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) { + nodes, err := mgr.getAllNodesWithSession(ctx, se) + if err != nil { + return nil, err + } + nodeMap := make(map[string][]proto.ManagedNode, 2) + for _, node := range nodes { + nodeMap[node.Role] = append(nodeMap[node.Role], node) + } + if len(nodeMap["background"]) == 0 { + return nodeMap[""], nil + } + return nodeMap["background"], nil +} + +// GetAllNodes gets nodes in dist_framework_meta. +func (mgr *TaskManager) GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) { + var nodes []proto.ManagedNode + err := mgr.WithNewSession(func(se sessionctx.Context) error { + var err2 error + nodes, err2 = mgr.getAllNodesWithSession(ctx, se) + return err2 + }) + return nodes, err +} + +func (*TaskManager) getAllNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) { + rs, err := sqlexec.ExecSQL(ctx, se, ` + select host, role, cpu_count + from mysql.dist_framework_meta + order by host`) + if err != nil { + return nil, err + } + nodes := make([]proto.ManagedNode, 0, len(rs)) + for _, r := range rs { + nodes = append(nodes, proto.ManagedNode{ + ID: r.GetString(0), + Role: r.GetString(1), + CPUCount: int(r.GetInt64(2)), + }) + } + return nodes, nil +} + +// GetUsedSlotsOnNodes implements the scheduler.TaskManager interface. +func (mgr *TaskManager) GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) { + // concurrency of subtasks of some step is the same, we use max(concurrency) + // to make group by works. + rs, err := mgr.ExecuteSQLWithNewSession(ctx, ` + select + exec_id, sum(concurrency) + from ( + select exec_id, task_key, max(concurrency) concurrency + from mysql.tidb_background_subtask + where state in (%?, %?) + group by exec_id, task_key + ) a + group by exec_id`, + proto.SubtaskStatePending, proto.SubtaskStateRunning, + ) + if err != nil { + return nil, err + } + + slots := make(map[string]int, len(rs)) + for _, r := range rs { + val, _ := r.GetMyDecimal(1).ToInt() + slots[r.GetString(0)] = int(val) + } + return slots, nil +} + +// GetCPUCountOfManagedNode gets the cpu count of managed node. +func (mgr *TaskManager) GetCPUCountOfManagedNode(ctx context.Context) (int, error) { + var cnt int + err := mgr.WithNewSession(func(se sessionctx.Context) error { + var err2 error + cnt, err2 = mgr.getCPUCountOfManagedNode(ctx, se) + return err2 + }) + return cnt, err +} + +// getCPUCountOfManagedNode gets the cpu count of managed node. +// returns error when there's no managed node or no node has valid cpu count. +func (mgr *TaskManager) getCPUCountOfManagedNode(ctx context.Context, se sessionctx.Context) (int, error) { + nodes, err := mgr.getManagedNodesWithSession(ctx, se) + if err != nil { + return 0, err + } + if len(nodes) == 0 { + return 0, errors.New("no managed nodes") + } + var cpuCount int + for _, n := range nodes { + if n.CPUCount > 0 { + cpuCount = n.CPUCount + break + } + } + if cpuCount == 0 { + return 0, errors.New("no managed node have enough resource for dist task") + } + return cpuCount, nil +} diff --git a/pkg/disttask/framework/storage/subtask_state.go b/pkg/disttask/framework/storage/subtask_state.go new file mode 100644 index 0000000000..d3cd853fe5 --- /dev/null +++ b/pkg/disttask/framework/storage/subtask_state.go @@ -0,0 +1,147 @@ +// Copyright 2024 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/pingcap/tidb/pkg/disttask/framework/proto" + "github.com/pingcap/tidb/pkg/sessionctx" + "github.com/pingcap/tidb/pkg/util/sqlexec" +) + +// StartSubtask updates the subtask state to running. +func (mgr *TaskManager) StartSubtask(ctx context.Context, subtaskID int64, execID string) error { + err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + vars := se.GetSessionVars() + _, err := sqlexec.ExecSQL(ctx, + se, + `update mysql.tidb_background_subtask + set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() + where id = %? and exec_id = %?`, + proto.SubtaskStateRunning, + subtaskID, + execID) + if err != nil { + return err + } + if vars.StmtCtx.AffectedRows() == 0 { + return ErrSubtaskNotFound + } + return nil + }) + return err +} + +// FinishSubtask updates the subtask meta and mark state to succeed. +func (mgr *TaskManager) FinishSubtask(ctx context.Context, execID string, id int64, meta []byte) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask + set meta = %?, state = %?, state_update_time = unix_timestamp(), end_time = CURRENT_TIMESTAMP() + where id = %? and exec_id = %?`, + meta, proto.SubtaskStateSucceed, id, execID) + return err +} + +// FailSubtask update the task's subtask state to failed and set the err. +func (mgr *TaskManager) FailSubtask(ctx context.Context, execID string, taskID int64, err error) error { + if err == nil { + return nil + } + _, err1 := mgr.ExecuteSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask + set state = %?, + error = %?, + start_time = unix_timestamp(), + state_update_time = unix_timestamp(), + end_time = CURRENT_TIMESTAMP() + where exec_id = %? and + task_key = %? and + state in (%?, %?) + limit 1;`, + proto.SubtaskStateFailed, + serializeErr(err), + execID, + taskID, + proto.SubtaskStatePending, + proto.SubtaskStateRunning) + return err1 +} + +// CancelSubtask update the task's subtasks' state to canceled. +func (mgr *TaskManager) CancelSubtask(ctx context.Context, execID string, taskID int64) error { + _, err1 := mgr.ExecuteSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask + set state = %?, + start_time = unix_timestamp(), + state_update_time = unix_timestamp(), + end_time = CURRENT_TIMESTAMP() + where exec_id = %? and + task_key = %? and + state in (%?, %?) + limit 1;`, + proto.SubtaskStateCanceled, + execID, + taskID, + proto.SubtaskStatePending, + proto.SubtaskStateRunning) + return err1 +} + +// PauseSubtasks update all running/pending subtasks to pasued state. +func (mgr *TaskManager) PauseSubtasks(ctx context.Context, execID string, taskID int64) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, execID) + return err +} + +// ResumeSubtasks update all paused subtasks to pending state. +func (mgr *TaskManager) ResumeSubtasks(ctx context.Context, taskID int64) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, + `update mysql.tidb_background_subtask set state = "pending", error = null where task_key = %? and state = "paused"`, taskID) + return err +} + +// RunningSubtasksBack2Pending implements the taskexecutor.TaskTable interface. +func (mgr *TaskManager) RunningSubtasksBack2Pending(ctx context.Context, subtasks []*proto.Subtask) error { + // skip the update process. + if len(subtasks) == 0 { + return nil + } + err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { + for _, subtask := range subtasks { + _, err := sqlexec.ExecSQL(ctx, se, ` + update mysql.tidb_background_subtask + set state = %?, state_update_time = CURRENT_TIMESTAMP() + where id = %? and exec_id = %? and state = %?`, + proto.SubtaskStatePending, subtask.ID, subtask.ExecID, proto.SubtaskStateRunning) + if err != nil { + return err + } + } + return nil + }) + return err +} + +// UpdateSubtaskStateAndError updates the subtask state. +func (mgr *TaskManager) UpdateSubtaskStateAndError( + ctx context.Context, + execID string, + id int64, state proto.SubtaskState, subTaskErr error) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask + set state = %?, error = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`, + state, serializeErr(subTaskErr), id, execID) + return err +} diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index da8f5ce6cf..5e0050048f 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -63,7 +63,7 @@ func TestTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(1), id) - task, err := gm.GetOneTask(ctx) + task, err := testutil.GetOneTask(ctx, gm) require.NoError(t, err) require.Equal(t, int64(1), task.ID) require.Equal(t, "key1", task.Key) @@ -117,12 +117,12 @@ func TestTaskTable(t *testing.T) { id, err = gm.CreateTask(ctx, "key2", "test", 4, []byte("test")) require.NoError(t, err) - cancelling, err := gm.IsTaskCancelling(ctx, id) + cancelling, err := testutil.IsTaskCancelling(ctx, gm, id) require.NoError(t, err) require.False(t, cancelling) require.NoError(t, gm.CancelTask(ctx, id)) - cancelling, err = gm.IsTaskCancelling(ctx, id) + cancelling, err = testutil.IsTaskCancelling(ctx, gm, id) require.NoError(t, err) require.True(t, cancelling) @@ -212,7 +212,7 @@ func checkAfterSwitchStep(t *testing.T, startTime time.Time, task *proto.Task, s checkTaskStateStep(t, task, proto.TaskStateRunning, step) require.GreaterOrEqual(t, task.StartTime, startTime) require.GreaterOrEqual(t, task.StateUpdateTime, startTime) - gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task.ID, task.Step, proto.TaskStatePending) + gotSubtasks, err := tm.GetAllSubtasksByStepAndState(ctx, task.ID, task.Step, proto.SubtaskStatePending) require.NoError(t, err) require.Len(t, gotSubtasks, len(subtasks)) sort.Slice(gotSubtasks, func(i, j int) bool { @@ -329,7 +329,7 @@ func TestSwitchTaskStepInBatch(t *testing.T) { task2, err = tm.GetTaskByID(ctx, task2.ID) require.NoError(t, err) checkTaskStateStep(t, task2, proto.TaskStatePending, proto.StepInit) - gotSubtasks, err := tm.GetSubtasksByStepAndState(ctx, task2.ID, proto.StepOne, proto.TaskStatePending) + gotSubtasks, err := tm.GetAllSubtasksByStepAndState(ctx, task2.ID, proto.StepOne, proto.SubtaskStatePending) require.NoError(t, err) require.Len(t, gotSubtasks, 1) // run again, should success @@ -541,8 +541,7 @@ func TestSubTaskTable(t *testing.T) { ts := time.Now() time.Sleep(time.Second) - err = sm.StartSubtask(ctx, 1, "tidb1") - require.NoError(t, err) + require.NoError(t, sm.StartSubtask(ctx, 1, "tidb1")) err = sm.StartSubtask(ctx, 1, "tidb2") require.Error(t, storage.ErrSubtaskNotFound, err) @@ -577,8 +576,7 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.False(t, ok) - err = sm.DeleteSubtasksByTaskID(ctx, 1) - require.NoError(t, err) + require.NoError(t, testutil.DeleteSubtasksByTaskID(ctx, sm, 1)) ok, err = sm.HasSubtasksInStates(ctx, "tidb1", 1, proto.StepInit, proto.SubtaskStatePending, proto.SubtaskStateRunning) require.NoError(t, err) @@ -590,22 +588,20 @@ func TestSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(1), cntByStates[proto.SubtaskStateRevertPending]) - subtasks, err := sm.GetSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.TaskStateSucceed) + subtasks, err := sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.SubtaskStateSucceed) require.NoError(t, err) require.Len(t, subtasks, 0) - err = sm.FinishSubtask(ctx, "tidb1", 2, []byte{}) - require.NoError(t, err) + require.NoError(t, sm.FinishSubtask(ctx, "tidb1", 2, []byte{})) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.TaskStateSucceed) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 2, proto.StepInit, proto.SubtaskStateSucceed) require.NoError(t, err) require.Len(t, subtasks, 1) rowCount, err := sm.GetSubtaskRowCount(ctx, 2, proto.StepInit) require.NoError(t, err) require.Equal(t, int64(0), rowCount) - err = sm.UpdateSubtaskRowCount(ctx, 2, 100) - require.NoError(t, err) + require.NoError(t, sm.UpdateSubtaskRowCount(ctx, 2, 100)) rowCount, err = sm.GetSubtaskRowCount(ctx, 2, proto.StepInit) require.NoError(t, err) require.Equal(t, int64(100), rowCount) @@ -613,62 +609,42 @@ func TestSubTaskTable(t *testing.T) { // test UpdateSubtasksExecIDs // 1. update one subtask testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb2" require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb2", subtasks[0].ExecID) // 2. update 2 subtasks testutil.CreateSubTask(t, sm, 5, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) subtasks[0].ExecID = "tidb3" require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb3", subtasks[0].ExecID) require.Equal(t, "tidb1", subtasks[1].ExecID) // update fail require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateRunning, nil)) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb3", subtasks[0].ExecID) subtasks[0].ExecID = "tidb2" // update success require.NoError(t, sm.UpdateSubtasksExecIDs(ctx, subtasks)) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.TaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 5, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, "tidb2", subtasks[0].ExecID) - // test GetSubtasksByExecIdsAndStepAndState - testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 6, proto.StepInit, proto.TaskStatePending) - require.NoError(t, err) - require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateRunning, nil)) - subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1"}, 6, proto.StepInit, proto.SubtaskStateRunning) - require.NoError(t, err) - require.Equal(t, 1, len(subtasks)) - subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1"}, 6, proto.StepInit, proto.SubtaskStatePending) - require.NoError(t, err) - require.Equal(t, 1, len(subtasks)) - testutil.CreateSubTask(t, sm, 6, proto.StepInit, "tidb2", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1", "tidb2"}, 6, proto.StepInit, proto.SubtaskStatePending) - require.NoError(t, err) - require.Equal(t, 2, len(subtasks)) - subtasks, err = sm.GetSubtasksByExecIdsAndStepAndState(ctx, []string{"tidb1"}, 6, proto.StepInit, proto.SubtaskStatePending) - require.NoError(t, err) - require.Equal(t, 1, len(subtasks)) - - // test CollectSubTaskError + // test GetSubtaskErrors testutil.CreateSubTask(t, sm, 7, proto.StepInit, "tidb1", []byte("test"), proto.TaskTypeExample, 11, false) - subtasks, err = sm.GetSubtasksByStepAndState(ctx, 7, proto.StepInit, proto.TaskStatePending) + subtasks, err = sm.GetAllSubtasksByStepAndState(ctx, 7, proto.StepInit, proto.SubtaskStatePending) require.NoError(t, err) require.Equal(t, 1, len(subtasks)) require.NoError(t, sm.UpdateSubtaskStateAndError(ctx, "tidb1", subtasks[0].ID, proto.SubtaskStateFailed, errors.New("test err"))) - subtaskErrs, err := sm.CollectSubTaskError(ctx, 7) + subtaskErrs, err := sm.GetSubtaskErrors(ctx, 7) require.NoError(t, err) require.Equal(t, 1, len(subtaskErrs)) require.ErrorContains(t, subtaskErrs[0], "test err") @@ -681,7 +657,7 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(1), id) - task, err := sm.GetOneTask(ctx) + task, err := testutil.GetOneTask(ctx, sm) require.NoError(t, err) require.Equal(t, proto.TaskStatePending, task.State) @@ -769,7 +745,7 @@ func TestBothTaskAndSubTaskTable(t *testing.T) { require.Equal(t, int64(2), cntByStates[proto.SubtaskStateRevertPending]) // test transactional - require.NoError(t, sm.DeleteSubtasksByTaskID(ctx, 1)) + require.NoError(t, testutil.DeleteSubtasksByTaskID(ctx, sm, 1)) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr", "1*return(true)")) defer func() { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/storage/MockUpdateTaskErr")) diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 0f735c329c..524212b903 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -16,11 +16,9 @@ package storage import ( "context" - "fmt" "strconv" "strings" "sync/atomic" - "time" "github.com/docker/go-units" "github.com/ngaut/pools" @@ -31,21 +29,19 @@ import ( "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/util/chunk" - "github.com/pingcap/tidb/pkg/util/cpu" "github.com/pingcap/tidb/pkg/util/intest" - "github.com/pingcap/tidb/pkg/util/logutil" "github.com/pingcap/tidb/pkg/util/sqlescape" "github.com/pingcap/tidb/pkg/util/sqlexec" "github.com/tikv/client-go/v2/util" - "go.uber.org/zap" ) const ( defaultSubtaskKeepDays = 14 basicTaskColumns = `id, task_key, type, state, step, priority, concurrency, create_time` + // TaskColumns is the columns for task. // TODO: dispatcher_id will update to scheduler_id later - taskColumns = basicTaskColumns + `, start_time, state_update_time, meta, dispatcher_id, error` + TaskColumns = basicTaskColumns + `, start_time, state_update_time, meta, dispatcher_id, error` // InsertTaskColumns is the columns used in insert task. InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time` basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal` @@ -132,48 +128,6 @@ func SetTaskManager(is *TaskManager) { taskManagerInstance.Store(is) } -func row2TaskBasic(r chunk.Row) *proto.Task { - task := &proto.Task{ - ID: r.GetInt64(0), - Key: r.GetString(1), - Type: proto.TaskType(r.GetString(2)), - State: proto.TaskState(r.GetString(3)), - Step: proto.Step(r.GetInt64(4)), - Priority: int(r.GetInt64(5)), - Concurrency: int(r.GetInt64(6)), - } - task.CreateTime, _ = r.GetTime(7).GoTime(time.Local) - return task -} - -// row2Task converts a row to a task. -func row2Task(r chunk.Row) *proto.Task { - task := row2TaskBasic(r) - var startTime, updateTime time.Time - if !r.IsNull(8) { - startTime, _ = r.GetTime(8).GoTime(time.Local) - } - if !r.IsNull(9) { - updateTime, _ = r.GetTime(9).GoTime(time.Local) - } - task.StartTime = startTime - task.StateUpdateTime = updateTime - task.Meta = r.GetBytes(10) - task.SchedulerID = r.GetString(11) - if !r.IsNull(12) { - errBytes := r.GetBytes(12) - stdErr := errors.Normalize("") - err := stdErr.UnmarshalJSON(errBytes) - if err != nil { - logutil.BgLogger().Error("unmarshal task error", zap.Error(err)) - task.Error = errors.New(string(errBytes)) - } else { - task.Error = stdErr - } - } - return task -} - // WithNewSession executes the function with a new session. func (mgr *TaskManager) WithNewSession(fn func(se sessionctx.Context) error) error { se, err := mgr.sePool.Get() @@ -266,20 +220,6 @@ func (mgr *TaskManager) CreateTaskWithSession(ctx context.Context, se sessionctx return taskID, nil } -// GetOneTask get a task from task table, it's used by scheduler only. -func (mgr *TaskManager) GetOneTask(ctx context.Context) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending) - if err != nil { - return task, err - } - - if len(rs) == 0 { - return nil, nil - } - - return row2Task(rs[0]), nil -} - // GetTopUnfinishedTasks implements the scheduler.TaskManager interface. func (mgr *TaskManager) GetTopUnfinishedTasks(ctx context.Context) (task []*proto.Task, err error) { rs, err := mgr.ExecuteSQLWithNewSession(ctx, @@ -312,7 +252,7 @@ func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...interfac } rs, err := mgr.ExecuteSQLWithNewSession(ctx, - "select "+taskColumns+" from mysql.tidb_global_task "+ + "select "+TaskColumns+" from mysql.tidb_global_task "+ "where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)"+ " order by priority asc, create_time asc, id asc", states...) if err != nil { @@ -320,31 +260,14 @@ func (mgr *TaskManager) GetTasksInStates(ctx context.Context, states ...interfac } for _, r := range rs { - task = append(task, row2Task(r)) - } - return task, nil -} - -// GetTasksFromHistoryInStates gets the tasks in history table in the states. -func (mgr *TaskManager) GetTasksFromHistoryInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) { - if len(states) == 0 { - return task, nil - } - - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task_history where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...) - if err != nil { - return task, err - } - - for _, r := range rs { - task = append(task, row2Task(r)) + task = append(task, Row2Task(r)) } return task, nil } // GetTaskByID gets the task by the task ID. func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where id = %?", taskID) + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where id = %?", taskID) if err != nil { return task, err } @@ -352,13 +275,13 @@ func (mgr *TaskManager) GetTaskByID(ctx context.Context, taskID int64) (task *pr return nil, ErrTaskNotFound } - return row2Task(rs[0]), nil + return Row2Task(rs[0]), nil } // GetTaskByIDWithHistory gets the task by the task ID from both tidb_global_task and tidb_global_task_history. func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where id = %? "+ - "union select "+taskColumns+" from mysql.tidb_global_task_history where id = %?", taskID, taskID) + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where id = %? "+ + "union select "+TaskColumns+" from mysql.tidb_global_task_history where id = %?", taskID, taskID) if err != nil { return task, err } @@ -366,12 +289,12 @@ func (mgr *TaskManager) GetTaskByIDWithHistory(ctx context.Context, taskID int64 return nil, ErrTaskNotFound } - return row2Task(rs[0]), nil + return Row2Task(rs[0]), nil } // GetTaskByKey gets the task by the task key. func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where task_key = %?", key) + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where task_key = %?", key) if err != nil { return task, err } @@ -379,13 +302,13 @@ func (mgr *TaskManager) GetTaskByKey(ctx context.Context, key string) (task *pro return nil, ErrTaskNotFound } - return row2Task(rs[0]), nil + return Row2Task(rs[0]), nil } // GetTaskByKeyWithHistory gets the task from history table by the task key. func (mgr *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) (task *proto.Task, err error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+taskColumns+" from mysql.tidb_global_task where task_key = %?"+ - "union select "+taskColumns+" from mysql.tidb_global_task_history where task_key = %?", key, key) + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+TaskColumns+" from mysql.tidb_global_task where task_key = %?"+ + "union select "+TaskColumns+" from mysql.tidb_global_task_history where task_key = %?", key, key) if err != nil { return task, err } @@ -393,87 +316,12 @@ func (mgr *TaskManager) GetTaskByKeyWithHistory(ctx context.Context, key string) return nil, ErrTaskNotFound } - return row2Task(rs[0]), nil + return Row2Task(rs[0]), nil } -// GetUsedSlotsOnNodes implements the scheduler.TaskManager interface. -func (mgr *TaskManager) GetUsedSlotsOnNodes(ctx context.Context) (map[string]int, error) { - // concurrency of subtasks of some step is the same, we use max(concurrency) - // to make group by works. - rs, err := mgr.ExecuteSQLWithNewSession(ctx, ` - select - exec_id, sum(concurrency) - from ( - select exec_id, task_key, max(concurrency) concurrency - from mysql.tidb_background_subtask - where state in (%?, %?) - group by exec_id, task_key - ) a - group by exec_id`, - proto.TaskStatePending, proto.TaskStateRunning, - ) - if err != nil { - return nil, err - } - - slots := make(map[string]int, len(rs)) - for _, r := range rs { - val, _ := r.GetMyDecimal(1).ToInt() - slots[r.GetString(0)] = int(val) - } - return slots, nil -} - -// row2BasicSubTask converts a row to a subtask with basic info -func row2BasicSubTask(r chunk.Row) *proto.Subtask { - taskIDStr := r.GetString(2) - tid, err := strconv.Atoi(taskIDStr) - if err != nil { - logutil.BgLogger().Warn("unexpected subtask id", zap.String("subtask-id", taskIDStr)) - } - createTime, _ := r.GetTime(7).GoTime(time.Local) - var ordinal int - if !r.IsNull(8) { - ordinal = int(r.GetInt64(8)) - } - subtask := &proto.Subtask{ - ID: r.GetInt64(0), - Step: proto.Step(r.GetInt64(1)), - TaskID: int64(tid), - Type: proto.Int2Type(int(r.GetInt64(3))), - ExecID: r.GetString(4), - State: proto.SubtaskState(r.GetString(5)), - Concurrency: int(r.GetInt64(6)), - CreateTime: createTime, - Ordinal: ordinal, - } - return subtask -} - -// Row2SubTask converts a row to a subtask. -func Row2SubTask(r chunk.Row) *proto.Subtask { - subtask := row2BasicSubTask(r) - // subtask defines start/update time as bigint, to ensure backward compatible, - // we keep it that way, and we convert it here. - var startTime, updateTime time.Time - if !r.IsNull(9) { - ts := r.GetInt64(9) - startTime = time.Unix(ts, 0) - } - if !r.IsNull(10) { - ts := r.GetInt64(10) - updateTime = time.Unix(ts, 0) - } - subtask.StartTime = startTime - subtask.UpdateTime = updateTime - subtask.Meta = r.GetBytes(11) - subtask.Summary = r.GetJSON(12).String() - return subtask -} - -// GetSubtasksByStepAndStates gets all subtasks by given states. -func (mgr *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) { - args := []interface{}{tidbID, taskID, step} +// GetSubtasksByExecIDAndStepAndStates gets all subtasks by given states on one node. +func (mgr *TaskManager) GetSubtasksByExecIDAndStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) { + args := []interface{}{execID, taskID, step} for _, state := range states { args = append(args, state) } @@ -491,26 +339,6 @@ func (mgr *TaskManager) GetSubtasksByStepAndStates(ctx context.Context, tidbID s return subtasks, nil } -// GetSubtasksByExecIdsAndStepAndState gets all subtasks by given taskID, exec_id, step and state. -func (mgr *TaskManager) GetSubtasksByExecIdsAndStepAndState(ctx context.Context, execIDs []string, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) { - args := []interface{}{taskID, step, state} - for _, execID := range execIDs { - args = append(args, execID) - } - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask - where task_key = %? and step = %? and state = %? - and exec_id in (`+strings.Repeat("%?,", len(execIDs)-1)+"%?)", args...) - if err != nil { - return nil, err - } - - subtasks := make([]*proto.Subtask, len(rs)) - for i, row := range rs { - subtasks[i] = Row2SubTask(row) - } - return subtasks, nil -} - // GetFirstSubtaskInStates gets the first subtask by given states. func (mgr *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (*proto.Subtask, error) { args := []interface{}{tidbID, taskID, step} @@ -530,51 +358,6 @@ func (mgr *TaskManager) GetFirstSubtaskInStates(ctx context.Context, tidbID stri return Row2SubTask(rs[0]), nil } -// FailSubtask update the task's subtask state to failed and set the err. -func (mgr *TaskManager) FailSubtask(ctx context.Context, execID string, taskID int64, err error) error { - if err == nil { - return nil - } - _, err1 := mgr.ExecuteSQLWithNewSession(ctx, - `update mysql.tidb_background_subtask - set state = %?, - error = %?, - start_time = unix_timestamp(), - state_update_time = unix_timestamp(), - end_time = CURRENT_TIMESTAMP() - where exec_id = %? and - task_key = %? and - state in (%?, %?) - limit 1;`, - proto.SubtaskStateFailed, - serializeErr(err), - execID, - taskID, - proto.SubtaskStatePending, - proto.SubtaskStateRunning) - return err1 -} - -// CancelSubtask update the task's subtasks' state to canceled. -func (mgr *TaskManager) CancelSubtask(ctx context.Context, execID string, taskID int64) error { - _, err1 := mgr.ExecuteSQLWithNewSession(ctx, - `update mysql.tidb_background_subtask - set state = %?, - start_time = unix_timestamp(), - state_update_time = unix_timestamp(), - end_time = CURRENT_TIMESTAMP() - where exec_id = %? and - task_key = %? and - state in (%?, %?) - limit 1;`, - proto.SubtaskStateCanceled, - execID, - taskID, - proto.SubtaskStatePending, - proto.SubtaskStateRunning) - return err1 -} - // GetActiveSubtasks implements TaskManager.GetActiveSubtasks. func (mgr *TaskManager) GetActiveSubtasks(ctx context.Context, taskID int64) ([]*proto.Subtask, error) { rs, err := mgr.ExecuteSQLWithNewSession(ctx, ` @@ -591,8 +374,8 @@ func (mgr *TaskManager) GetActiveSubtasks(ctx context.Context, taskID int64) ([] return subtasks, nil } -// GetSubtasksByStepAndState gets the subtask by step and state. -func (mgr *TaskManager) GetSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.TaskState) ([]*proto.Subtask, error) { +// GetAllSubtasksByStepAndState gets the subtask by step and state. +func (mgr *TaskManager) GetAllSubtasksByStepAndState(ctx context.Context, taskID int64, step proto.Step, state proto.SubtaskState) ([]*proto.Subtask, error) { rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select `+SubtaskColumns+` from mysql.tidb_background_subtask where task_key = %? and state = %? and step = %?`, taskID, state, step) @@ -654,8 +437,8 @@ func (mgr *TaskManager) GetSubtaskCntGroupByStates(ctx context.Context, taskID i return res, nil } -// CollectSubTaskError collects the subtask error. -func (mgr *TaskManager) CollectSubTaskError(ctx context.Context, taskID int64) ([]error, error) { +// GetSubtaskErrors gets subtasks' errors. +func (mgr *TaskManager) GetSubtaskErrors(ctx context.Context, taskID int64) ([]error, error) { rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select error from mysql.tidb_background_subtask where task_key = %? AND state in (%?, %?)`, taskID, proto.SubtaskStateFailed, proto.SubtaskStateCanceled) @@ -700,95 +483,6 @@ func (mgr *TaskManager) HasSubtasksInStates(ctx context.Context, tidbID string, return len(rs) > 0, nil } -// StartSubtask updates the subtask state to running. -func (mgr *TaskManager) StartSubtask(ctx context.Context, subtaskID int64, execID string) error { - err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - vars := se.GetSessionVars() - _, err := sqlexec.ExecSQL(ctx, - se, - `update mysql.tidb_background_subtask - set state = %?, start_time = unix_timestamp(), state_update_time = unix_timestamp() - where id = %? and exec_id = %?`, - proto.TaskStateRunning, - subtaskID, - execID) - if err != nil { - return err - } - if vars.StmtCtx.AffectedRows() == 0 { - return ErrSubtaskNotFound - } - return nil - }) - return err -} - -// InitMeta insert the manager information into dist_framework_meta. -func (mgr *TaskManager) InitMeta(ctx context.Context, tidbID string, role string) error { - return mgr.WithNewSession(func(se sessionctx.Context) error { - return mgr.InitMetaSession(ctx, se, tidbID, role) - }) -} - -// InitMetaSession insert the manager information into dist_framework_meta. -// if the record exists, update the cpu_count and role. -func (*TaskManager) InitMetaSession(ctx context.Context, se sessionctx.Context, execID string, role string) error { - cpuCount := cpu.GetCPUCount() - _, err := sqlexec.ExecSQL(ctx, se, ` - insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id) - values (%?, %?, %?, -1) - on duplicate key - update cpu_count = %?, role = %?`, - execID, role, cpuCount, cpuCount, role) - return err -} - -// RecoverMeta insert the manager information into dist_framework_meta. -// if the record exists, update the cpu_count. -// Don't update role for we only update it in `set global tidb_service_scope`. -// if not there might has a data race. -func (mgr *TaskManager) RecoverMeta(ctx context.Context, execID string, role string) error { - cpuCount := cpu.GetCPUCount() - _, err := mgr.ExecuteSQLWithNewSession(ctx, ` - insert into mysql.dist_framework_meta(host, role, cpu_count, keyspace_id) - values (%?, %?, %?, -1) - on duplicate key - update cpu_count = %?`, - execID, role, cpuCount, cpuCount) - return err -} - -// UpdateSubtaskStateAndError updates the subtask state. -func (mgr *TaskManager) UpdateSubtaskStateAndError( - ctx context.Context, - execID string, - id int64, state proto.SubtaskState, subTaskErr error) error { - _, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask - set state = %?, error = %?, state_update_time = unix_timestamp() where id = %? and exec_id = %?`, - state, serializeErr(subTaskErr), id, execID) - return err -} - -// FinishSubtask updates the subtask meta and mark state to succeed. -func (mgr *TaskManager) FinishSubtask(ctx context.Context, execID string, id int64, meta []byte) error { - _, err := mgr.ExecuteSQLWithNewSession(ctx, `update mysql.tidb_background_subtask - set meta = %?, state = %?, state_update_time = unix_timestamp(), end_time = CURRENT_TIMESTAMP() - where id = %? and exec_id = %?`, - meta, proto.TaskStateSucceed, id, execID) - return err -} - -// DeleteSubtasksByTaskID deletes the subtask of the given task ID. -func (mgr *TaskManager) DeleteSubtasksByTaskID(ctx context.Context, taskID int64) error { - _, err := mgr.ExecuteSQLWithNewSession(ctx, `delete from mysql.tidb_background_subtask - where task_key = %?`, taskID) - if err != nil { - return err - } - - return nil -} - // GetTaskExecutorIDsByTaskID gets the task executor IDs of the given task ID. func (mgr *TaskManager) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID int64) ([]string, error) { rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select distinct(exec_id) from mysql.tidb_background_subtask @@ -809,26 +503,6 @@ func (mgr *TaskManager) GetTaskExecutorIDsByTaskID(ctx context.Context, taskID i return instanceIDs, nil } -// GetTaskExecutorIDsByTaskIDAndStep gets the task executor IDs of the given global task ID and step. -func (mgr *TaskManager) GetTaskExecutorIDsByTaskIDAndStep(ctx context.Context, taskID int64, step proto.Step) ([]string, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, `select distinct(exec_id) from mysql.tidb_background_subtask - where task_key = %? and step = %?`, taskID, step) - if err != nil { - return nil, err - } - if len(rs) == 0 { - return nil, nil - } - - instanceIDs := make([]string, 0, len(rs)) - for _, r := range rs { - id := r.GetString(0) - instanceIDs = append(instanceIDs, id) - } - - return instanceIDs, nil -} - // UpdateSubtasksExecIDs update subtasks' execID. func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*proto.Subtask) error { // skip the update process. @@ -851,64 +525,6 @@ func (mgr *TaskManager) UpdateSubtasksExecIDs(ctx context.Context, subtasks []*p return err } -// RunningSubtasksBack2Pending implements the taskexecutor.TaskTable interface. -func (mgr *TaskManager) RunningSubtasksBack2Pending(ctx context.Context, subtasks []*proto.Subtask) error { - // skip the update process. - if len(subtasks) == 0 { - return nil - } - err := mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - for _, subtask := range subtasks { - _, err := sqlexec.ExecSQL(ctx, se, ` - update mysql.tidb_background_subtask - set state = %?, state_update_time = CURRENT_TIMESTAMP() - where id = %? and exec_id = %? and state = %?`, - proto.SubtaskStatePending, subtask.ID, subtask.ExecID, proto.SubtaskStateRunning) - if err != nil { - return err - } - } - return nil - }) - return err -} - -// DeleteDeadNodes deletes the dead nodes from mysql.dist_framework_meta. -func (mgr *TaskManager) DeleteDeadNodes(ctx context.Context, nodes []string) error { - if len(nodes) == 0 { - return nil - } - return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - deleteSQL := new(strings.Builder) - if err := sqlescape.FormatSQL(deleteSQL, "delete from mysql.dist_framework_meta where host in("); err != nil { - return err - } - deleteElems := make([]string, 0, len(nodes)) - for _, node := range nodes { - deleteElems = append(deleteElems, fmt.Sprintf(`"%s"`, node)) - } - - deleteSQL.WriteString(strings.Join(deleteElems, ", ")) - deleteSQL.WriteString(")") - _, err := sqlexec.ExecSQL(ctx, se, deleteSQL.String()) - return err - }) -} - -// PauseSubtasks update all running/pending subtasks to pasued state. -func (mgr *TaskManager) PauseSubtasks(ctx context.Context, execID string, taskID int64) error { - _, err := mgr.ExecuteSQLWithNewSession(ctx, - `update mysql.tidb_background_subtask set state = "paused" where task_key = %? and state in ("running", "pending") and exec_id = %?`, taskID, execID) - return err -} - -// ResumeSubtasks update all paused subtasks to pending state. -func (mgr *TaskManager) ResumeSubtasks(ctx context.Context, taskID int64) error { - _, err := mgr.ExecuteSQLWithNewSession(ctx, - `update mysql.tidb_background_subtask set state = "pending", error = null where task_key = %? and state = "paused"`, taskID) - return err -} - // SwitchTaskStep implements the dispatcher.TaskManager interface. func (mgr *TaskManager) SwitchTaskStep( ctx context.Context, @@ -983,7 +599,7 @@ func (*TaskManager) insertSubtasks(ctx context.Context, se sessionctx.Context, s for _, subtask := range subtasks { markerList = append(markerList, "(%?, %?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), '{}', '{}')") args = append(args, subtask.Step, subtask.TaskID, subtask.ExecID, subtask.Meta, - proto.TaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) + proto.SubtaskStatePending, proto.Type2Int(subtask.Type), subtask.Concurrency, subtask.Ordinal) } sb.WriteString(strings.Join(markerList, ",")) _, err := sqlexec.ExecSQL(ctx, se, sb.String(), args...) @@ -1135,19 +751,6 @@ func serializeErr(err error) []byte { return errBytes } -// IsTaskCancelling checks whether the task state is cancelling. -func (mgr *TaskManager) IsTaskCancelling(ctx context.Context, taskID int64) (bool, error) { - rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?", - taskID, proto.TaskStateCancelling, - ) - - if err != nil { - return false, err - } - - return len(rs) > 0, nil -} - // GetSubtasksWithHistory gets the subtasks from tidb_global_task and tidb_global_task_history. func (mgr *TaskManager) GetSubtasksWithHistory(ctx context.Context, taskID int64, step proto.Step) ([]*proto.Subtask, error) { var ( @@ -1189,161 +792,3 @@ func (mgr *TaskManager) GetSubtasksWithHistory(ctx context.Context, taskID int64 } return subtasks, nil } - -// TransferSubtasks2HistoryWithSession transfer the selected subtasks into tidb_background_subtask_history table by taskID. -func (*TaskManager) TransferSubtasks2HistoryWithSession(ctx context.Context, se sessionctx.Context, taskID int64) error { - _, err := sqlexec.ExecSQL(ctx, se, `insert into mysql.tidb_background_subtask_history select * from mysql.tidb_background_subtask where task_key = %?`, taskID) - if err != nil { - return err - } - // delete taskID subtask - _, err = sqlexec.ExecSQL(ctx, se, "delete from mysql.tidb_background_subtask where task_key = %?", taskID) - return err -} - -// TransferTasks2History transfer the selected tasks into tidb_global_task_history table by taskIDs. -func (mgr *TaskManager) TransferTasks2History(ctx context.Context, tasks []*proto.Task) error { - if len(tasks) == 0 { - return nil - } - taskIDStrs := make([]string, 0, len(tasks)) - for _, task := range tasks { - taskIDStrs = append(taskIDStrs, fmt.Sprintf("%d", task.ID)) - } - return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error { - // sensitive data in meta might be redacted, need update first. - for _, t := range tasks { - _, err := sqlexec.ExecSQL(ctx, se, ` - update mysql.tidb_global_task - set meta= %?, state_update_time = CURRENT_TIMESTAMP() - where id = %?`, t.Meta, t.ID) - if err != nil { - return err - } - } - _, err := sqlexec.ExecSQL(ctx, se, ` - insert into mysql.tidb_global_task_history - select * from mysql.tidb_global_task - where id in(`+strings.Join(taskIDStrs, `, `)+`)`) - if err != nil { - return err - } - - _, err = sqlexec.ExecSQL(ctx, se, ` - delete from mysql.tidb_global_task - where id in(`+strings.Join(taskIDStrs, `, `)+`)`) - - for _, t := range tasks { - err = mgr.TransferSubtasks2HistoryWithSession(ctx, se, t.ID) - if err != nil { - return err - } - } - return err - }) -} - -// GCSubtasks deletes the history subtask which is older than the given days. -func (mgr *TaskManager) GCSubtasks(ctx context.Context) error { - subtaskHistoryKeepSeconds := defaultSubtaskKeepDays * 24 * 60 * 60 - failpoint.Inject("subtaskHistoryKeepSeconds", func(val failpoint.Value) { - if val, ok := val.(int); ok { - subtaskHistoryKeepSeconds = val - } - }) - _, err := mgr.ExecuteSQLWithNewSession( - ctx, - fmt.Sprintf("DELETE FROM mysql.tidb_background_subtask_history WHERE state_update_time < UNIX_TIMESTAMP() - %d ;", subtaskHistoryKeepSeconds), - ) - return err -} - -// GetManagedNodes implements scheduler.TaskManager interface. -func (mgr *TaskManager) GetManagedNodes(ctx context.Context) ([]proto.ManagedNode, error) { - var nodes []proto.ManagedNode - err := mgr.WithNewSession(func(se sessionctx.Context) error { - var err2 error - nodes, err2 = mgr.getManagedNodesWithSession(ctx, se) - return err2 - }) - return nodes, err -} - -func (mgr *TaskManager) getManagedNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) { - nodes, err := mgr.getAllNodesWithSession(ctx, se) - if err != nil { - return nil, err - } - nodeMap := make(map[string][]proto.ManagedNode, 2) - for _, node := range nodes { - nodeMap[node.Role] = append(nodeMap[node.Role], node) - } - if len(nodeMap["background"]) == 0 { - return nodeMap[""], nil - } - return nodeMap["background"], nil -} - -// GetAllNodes gets nodes in dist_framework_meta. -func (mgr *TaskManager) GetAllNodes(ctx context.Context) ([]proto.ManagedNode, error) { - var nodes []proto.ManagedNode - err := mgr.WithNewSession(func(se sessionctx.Context) error { - var err2 error - nodes, err2 = mgr.getAllNodesWithSession(ctx, se) - return err2 - }) - return nodes, err -} - -func (*TaskManager) getAllNodesWithSession(ctx context.Context, se sessionctx.Context) ([]proto.ManagedNode, error) { - rs, err := sqlexec.ExecSQL(ctx, se, ` - select host, role, cpu_count - from mysql.dist_framework_meta - order by host`) - if err != nil { - return nil, err - } - nodes := make([]proto.ManagedNode, 0, len(rs)) - for _, r := range rs { - nodes = append(nodes, proto.ManagedNode{ - ID: r.GetString(0), - Role: r.GetString(1), - CPUCount: int(r.GetInt64(2)), - }) - } - return nodes, nil -} - -// GetCPUCountOfManagedNode gets the cpu count of managed node. -func (mgr *TaskManager) GetCPUCountOfManagedNode(ctx context.Context) (int, error) { - var cnt int - err := mgr.WithNewSession(func(se sessionctx.Context) error { - var err2 error - cnt, err2 = mgr.getCPUCountOfManagedNode(ctx, se) - return err2 - }) - return cnt, err -} - -// getCPUCountOfManagedNode gets the cpu count of managed node. -// returns error when there's no managed node or no node has valid cpu count. -func (mgr *TaskManager) getCPUCountOfManagedNode(ctx context.Context, se sessionctx.Context) (int, error) { - nodes, err := mgr.getManagedNodesWithSession(ctx, se) - if err != nil { - return 0, err - } - if len(nodes) == 0 { - return 0, errors.New("no managed nodes") - } - var cpuCount int - for _, n := range nodes { - if n.CPUCount > 0 { - cpuCount = n.CPUCount - break - } - } - if cpuCount == 0 { - return 0, errors.New("no managed node have enough resource for dist task") - } - return cpuCount, nil -} diff --git a/pkg/disttask/framework/taskexecutor/interface.go b/pkg/disttask/framework/taskexecutor/interface.go index 7c1e7771c3..4d02106dc8 100644 --- a/pkg/disttask/framework/taskexecutor/interface.go +++ b/pkg/disttask/framework/taskexecutor/interface.go @@ -25,7 +25,8 @@ import ( type TaskTable interface { GetTasksInStates(ctx context.Context, states ...interface{}) (task []*proto.Task, err error) GetTaskByID(ctx context.Context, taskID int64) (task *proto.Task, err error) - GetSubtasksByStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) + // GetSubtasksByExecIDAndStepAndStates gets all subtasks by given states and execID. + GetSubtasksByExecIDAndStepAndStates(ctx context.Context, execID string, taskID int64, step proto.Step, states ...proto.SubtaskState) ([]*proto.Subtask, error) GetFirstSubtaskInStates(ctx context.Context, instanceID string, taskID int64, step proto.Step, states ...proto.SubtaskState) (*proto.Subtask, error) // InitMeta insert the manager information into dist_framework_meta. // Call it when starting task executor or in set variable operation. diff --git a/pkg/disttask/framework/taskexecutor/task_executor.go b/pkg/disttask/framework/taskexecutor/task_executor.go index ff1c53a3d3..4be0bc1020 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor.go +++ b/pkg/disttask/framework/taskexecutor/task_executor.go @@ -111,7 +111,7 @@ func (s *BaseTaskExecutor) checkBalanceSubtask(ctx context.Context) { } task := s.task.Load() - subtasks, err := s.taskTable.GetSubtasksByStepAndStates(ctx, s.id, task.ID, task.Step, + subtasks, err := s.taskTable.GetSubtasksByExecIDAndStepAndStates(ctx, s.id, task.ID, task.Step, proto.SubtaskStateRunning) if err != nil { logutil.Logger(s.logCtx).Error("get subtasks failed", zap.Error(err)) @@ -231,7 +231,7 @@ func (s *BaseTaskExecutor) runStep(ctx context.Context, task *proto.Task) (resEr } }() - subtasks, err := s.taskTable.GetSubtasksByStepAndStates( + subtasks, err := s.taskTable.GetSubtasksByExecIDAndStepAndStates( runCtx, s.id, task.ID, task.Step, proto.SubtaskStatePending, proto.SubtaskStateRunning) if err != nil { diff --git a/pkg/disttask/framework/taskexecutor/task_executor_test.go b/pkg/disttask/framework/taskexecutor/task_executor_test.go index 87abedcff7..c8b326862c 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_test.go @@ -56,7 +56,7 @@ func TestTaskExecutorRun(t *testing.T) { mockExtension.EXPECT().IsRetryableError(gomock.Any()).AnyTimes() // mock for checkBalanceSubtask - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, proto.SubtaskStateRunning).Return([]*proto.Subtask{{ID: 1}}, nil).AnyTimes() task1 := &proto.Task{Step: proto.StepOne, Type: tp} @@ -85,7 +85,7 @@ func TestTaskExecutorRun(t *testing.T) { // 3. run subtask failed runSubtaskErr := errors.New("run subtask error") mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, @@ -102,7 +102,7 @@ func TestTaskExecutorRun(t *testing.T) { // 4. run subtask success mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, @@ -122,7 +122,7 @@ func TestTaskExecutorRun(t *testing.T) { // 5. run subtask one by one mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return( []*proto.Subtask{ {ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}, @@ -157,7 +157,7 @@ func TestTaskExecutorRun(t *testing.T) { // idempotent, so fail it. subtaskID := int64(2) theSubtask := &proto.Subtask{ID: subtaskID, Type: tp, Step: proto.StepOne, State: proto.SubtaskStateRunning, ExecID: "id"} - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{theSubtask}, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, @@ -172,7 +172,7 @@ func TestTaskExecutorRun(t *testing.T) { // run previous left subtask in running state again, but the subtask idempotent, // run it again. theSubtask = &proto.Subtask{ID: subtaskID, Type: tp, Step: proto.StepOne, State: proto.SubtaskStateRunning, ExecID: "id"} - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{theSubtask}, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) // first round of the run loop @@ -192,7 +192,7 @@ func TestTaskExecutorRun(t *testing.T) { require.True(t, ctrl.Satisfied()) // 6. cancel - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) @@ -208,7 +208,7 @@ func TestTaskExecutorRun(t *testing.T) { require.True(t, ctrl.Satisfied()) // 7. RunSubtask return context.Canceled - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) @@ -223,7 +223,7 @@ func TestTaskExecutorRun(t *testing.T) { require.True(t, ctrl.Satisfied()) // 8. grpc cancel - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) @@ -239,7 +239,7 @@ func TestTaskExecutorRun(t *testing.T) { require.True(t, ctrl.Satisfied()) // 9. annotate grpc cancel - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ ID: 2, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) @@ -261,7 +261,7 @@ func TestTaskExecutorRun(t *testing.T) { // 10. subtask owned by other executor mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return([]*proto.Subtask{{ ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}}, nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, @@ -378,7 +378,7 @@ func TestTaskExecutor(t *testing.T) { mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil).AnyTimes() mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(false).AnyTimes() // mock for checkBalanceSubtask - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, proto.SubtaskStateRunning).Return([]*proto.Subtask{{ID: 1}}, nil).AnyTimes() task := &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency} @@ -391,7 +391,7 @@ func TestTaskExecutor(t *testing.T) { subtasks := []*proto.Subtask{ {ID: 1, Type: tp, Step: proto.StepOne, State: proto.SubtaskStatePending, ExecID: "id"}, } - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(subtasks, nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(subtasks[0], nil) @@ -416,7 +416,7 @@ func TestTaskExecutor(t *testing.T) { // 3. run one subtask, then task moved to history(ErrTaskNotFound). mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(subtasks, nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "id", taskID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(subtasks[0], nil) @@ -449,12 +449,12 @@ func TestRunStepCurrentSubtaskScheduledAway(t *testing.T) { taskExecutor.Extension = mockExtension // mock for checkBalanceSubtask - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, proto.StepOne, proto.SubtaskStateRunning).Return([]*proto.Subtask{}, nil) // mock for runStep mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil) mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(false) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", task.ID, proto.StepOne, + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(subtasks, nil) mockSubtaskTable.EXPECT().GetFirstSubtaskInStates(gomock.Any(), "tidb1", task.ID, proto.StepOne, unfinishedNormalSubtaskStates...).Return(subtasks[0], nil) @@ -492,9 +492,9 @@ func TestCheckBalanceSubtask(t *testing.T) { checkBalanceSubtaskInterval = 100 * time.Millisecond // subtask scheduled away - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, task.Step, proto.SubtaskStateRunning).Return(nil, errors.New("error")) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, task.Step, proto.SubtaskStateRunning).Return([]*proto.Subtask{}, nil) runCtx, cancelCause := context.WithCancelCause(ctx) taskExecutor.registerCancelFunc(cancelCause) @@ -505,7 +505,7 @@ func TestCheckBalanceSubtask(t *testing.T) { subtasks := []*proto.Subtask{{ID: 1, ExecID: "tidb1"}} // in-idempotent running subtask - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, task.Step, proto.SubtaskStateRunning).Return(subtasks, nil) mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(gomock.Any(), "tidb1", subtasks[0].ID, proto.SubtaskStateFailed, ErrNonIdempotentSubtask).Return(nil) @@ -517,12 +517,12 @@ func TestCheckBalanceSubtask(t *testing.T) { require.Zero(t, taskExecutor.currSubtaskID.Load()) taskExecutor.currSubtaskID.Store(1) subtasks = []*proto.Subtask{{ID: 1, ExecID: "tidb1"}, {ID: 2, ExecID: "tidb1"}} - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, task.Step, proto.SubtaskStateRunning).Return(subtasks, nil) mockExtension.EXPECT().IsIdempotent(gomock.Any()).Return(true) mockSubtaskTable.EXPECT().RunningSubtasksBack2Pending(gomock.Any(), []*proto.Subtask{{ID: 2, ExecID: "tidb1"}}).Return(nil) // used to break the loop - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates(gomock.Any(), "tidb1", + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates(gomock.Any(), "tidb1", task.ID, task.Step, proto.SubtaskStateRunning).Return(nil, nil) taskExecutor.checkBalanceSubtask(ctx) require.True(t, ctrl.Satisfied()) @@ -576,32 +576,32 @@ func TestExecutorErrHandling(t *testing.T) { require.True(t, ctrl.Satisfied()) // 5. GetSubtasksByStepAndStates meet retryable error. - getSubtasksByStepAndStatesErr := errors.New("get subtasks err") + getSubtasksByExecIDAndStepAndStatesErr := errors.New("get subtasks err") mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates( + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates( gomock.Any(), taskExecutor.id, gomock.Any(), proto.StepOne, - unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByStepAndStatesErr) + unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByExecIDAndStepAndStatesErr) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(true) require.NoError(t, taskExecutor.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})) require.True(t, ctrl.Satisfied()) - // 6. GetSubtasksByStepAndStates meet non retryable error. + // 6. GetSubtasksByExecIDAndStepAndStates meet non retryable error. mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates( + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates( gomock.Any(), taskExecutor.id, gomock.Any(), proto.StepOne, - unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByStepAndStatesErr) + unfinishedNormalSubtaskStates...).Return(nil, getSubtasksByExecIDAndStepAndStatesErr) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) mockExtension.EXPECT().IsRetryableError(gomock.Any()).Return(false) - mockSubtaskTable.EXPECT().FailSubtask(runCtx, taskExecutor.id, gomock.Any(), getSubtasksByStepAndStatesErr) + mockSubtaskTable.EXPECT().FailSubtask(runCtx, taskExecutor.id, gomock.Any(), getSubtasksByExecIDAndStepAndStatesErr) require.NoError(t, taskExecutor.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})) require.True(t, ctrl.Satisfied()) @@ -609,7 +609,7 @@ func TestExecutorErrHandling(t *testing.T) { cleanupErr := errors.New("cleanup err") mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates( + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates( gomock.Any(), taskExecutor.id, gomock.Any(), @@ -634,7 +634,7 @@ func TestExecutorErrHandling(t *testing.T) { // 8. Cleanup meet non retryable error. mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates( + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates( gomock.Any(), taskExecutor.id, gomock.Any(), @@ -687,7 +687,7 @@ func TestExecutorErrHandling(t *testing.T) { // 13. subtask succeed. mockExtension.EXPECT().GetSubtaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockSubtaskExecutor, nil) mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockSubtaskTable.EXPECT().GetSubtasksByStepAndStates( + mockSubtaskTable.EXPECT().GetSubtasksByExecIDAndStepAndStates( gomock.Any(), taskExecutor.id, gomock.Any(), diff --git a/pkg/disttask/framework/testutil/table_util.go b/pkg/disttask/framework/testutil/table_util.go index 6668d4cccc..1f2d9dfc24 100644 --- a/pkg/disttask/framework/testutil/table_util.go +++ b/pkg/disttask/framework/testutil/table_util.go @@ -17,6 +17,7 @@ package testutil import ( "context" "fmt" + "strings" "testing" "time" @@ -79,6 +80,20 @@ func getTaskManager(t *testing.T, pool *pools.ResourcePool) *storage.TaskManager return manager } +// GetOneTask get a task from task table +func GetOneTask(ctx context.Context, mgr *storage.TaskManager) (task *proto.Task, err error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+storage.TaskColumns+" from mysql.tidb_global_task where state = %? limit 1", proto.TaskStatePending) + if err != nil { + return task, err + } + + if len(rs) == 0 { + return nil, nil + } + + return storage.Row2Task(rs[0]), nil +} + // GetSubtasksFromHistory gets subtasks from history table for test. func GetSubtasksFromHistory(ctx context.Context, mgr *storage.TaskManager) (int, error) { rs, err := mgr.ExecuteSQLWithNewSession(ctx, @@ -191,6 +206,47 @@ func TransferSubTasks2History(ctx context.Context, mgr *storage.TaskManager, tas }) } +// GetTasksFromHistoryInStates gets the tasks in history table in the states. +func GetTasksFromHistoryInStates(ctx context.Context, mgr *storage.TaskManager, states ...interface{}) (task []*proto.Task, err error) { + if len(states) == 0 { + return task, nil + } + + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select "+storage.TaskColumns+" from mysql.tidb_global_task_history where state in ("+strings.Repeat("%?,", len(states)-1)+"%?)", states...) + if err != nil { + return task, err + } + + for _, r := range rs { + task = append(task, storage.Row2Task(r)) + } + return task, nil +} + +// DeleteSubtasksByTaskID deletes the subtask of the given task ID. +func DeleteSubtasksByTaskID(ctx context.Context, mgr *storage.TaskManager, taskID int64) error { + _, err := mgr.ExecuteSQLWithNewSession(ctx, `delete from mysql.tidb_background_subtask + where task_key = %?`, taskID) + if err != nil { + return err + } + + return nil +} + +// IsTaskCancelling checks whether the task state is cancelling. +func IsTaskCancelling(ctx context.Context, mgr *storage.TaskManager, taskID int64) (bool, error) { + rs, err := mgr.ExecuteSQLWithNewSession(ctx, "select 1 from mysql.tidb_global_task where id=%? and state = %?", + taskID, proto.TaskStateCancelling, + ) + + if err != nil { + return false, err + } + + return len(rs) > 0, nil +} + // PrintSubtaskInfo log the subtask info by taskKey for test. func PrintSubtaskInfo(ctx context.Context, mgr *storage.TaskManager, taskID int64) { rs, _ := mgr.ExecuteSQLWithNewSession(ctx, diff --git a/tests/realtikvtest/importintotest/import_into_test.go b/tests/realtikvtest/importintotest/import_into_test.go index e1857dfefa..aa600c3885 100644 --- a/tests/realtikvtest/importintotest/import_into_test.go +++ b/tests/realtikvtest/importintotest/import_into_test.go @@ -769,7 +769,7 @@ func (s *mockGCSSuite) TestColumnsAndUserVars() { taskManager := storage.NewTaskManager(pool) ctx := context.Background() ctx = util.WithInternalSourceType(ctx, "taskManager") - subtasks, err := taskManager.GetSubtasksByStepAndState(ctx, storage.TestLastTaskID.Load(), importinto.StepImport, proto.TaskStateSucceed) + subtasks, err := taskManager.GetAllSubtasksByStepAndState(ctx, storage.TestLastTaskID.Load(), importinto.StepImport, proto.SubtaskStateSucceed) s.NoError(err) s.Len(subtasks, 1) serverInfo, err := infosync.GetServerInfo()