From dbb493ff22094790438416e4bc94987eebcad659 Mon Sep 17 00:00:00 2001 From: EasonBall <592838129@qq.com> Date: Tue, 12 Sep 2023 09:00:09 +0800 Subject: [PATCH] disttask: dynamic dispatch subtasks (#46593) ref pingcap/tidb#46258 --- ddl/backfilling_dispatcher.go | 20 ++- ddl/backfilling_dispatcher_test.go | 36 +++- ddl/index.go | 2 +- ddl/stage_scheduler.go | 4 +- disttask/framework/BUILD.bazel | 3 +- disttask/framework/dispatcher/dispatcher.go | 162 ++++++++++++------ .../framework/dispatcher/dispatcher_test.go | 22 ++- disttask/framework/dispatcher/interface.go | 26 ++- .../framework_dynamic_dispatch_test.go | 114 ++++++++++++ .../framework/framework_err_handling_test.go | 49 +++++- disttask/framework/framework_ha_test.go | 54 +++--- disttask/framework/framework_rollback_test.go | 20 ++- disttask/framework/framework_test.go | 49 ++++-- disttask/framework/proto/task.go | 4 +- disttask/framework/scheduler/manager.go | 2 +- disttask/framework/scheduler/scheduler.go | 21 ++- .../framework/scheduler/scheduler_test.go | 35 +--- disttask/framework/storage/task_table.go | 24 ++- disttask/importinto/dispatcher.go | 58 ++++--- .../importinto/dispatcher_testkit_test.go | 11 +- disttask/importinto/proto.go | 4 +- 21 files changed, 525 insertions(+), 195 deletions(-) create mode 100644 disttask/framework/framework_dynamic_dispatch_test.go diff --git a/ddl/backfilling_dispatcher.go b/ddl/backfilling_dispatcher.go index 7e16672bde..194a6b17ce 100644 --- a/ddl/backfilling_dispatcher.go +++ b/ddl/backfilling_dispatcher.go @@ -50,11 +50,12 @@ func NewBackfillingDispatcherExt(d DDL) (dispatcher.Extension, error) { }, nil } +// OnTick implements dispatcher.Extension interface. func (*backfillingDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -// OnNextStage generate next stage's plan. -func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) { +// OnNextSubtasksBatch generate batch of next stage's plan. +func (h *backfillingDispatcherExt) OnNextSubtasksBatch(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) { var globalTaskMeta BackfillGlobalMeta if err := json.Unmarshal(gTask.Meta, &globalTaskMeta); err != nil { return nil, err @@ -76,7 +77,6 @@ func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher if err != nil { return nil, err } - gTask.Step = proto.StepOne return subTaskMetas, nil } @@ -87,7 +87,6 @@ func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher if err != nil { return nil, err } - gTask.Step = proto.StepOne return subtaskMeta, nil case proto.StepOne: serverNodes, err := dispatcher.GenerateSchedulerNodes(ctx) @@ -103,15 +102,22 @@ func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher for range serverNodes { subTaskMetas = append(subTaskMetas, metaBytes) } - gTask.Step = proto.StepTwo return subTaskMetas, nil - case proto.StepTwo: - return nil, nil default: return nil, nil } } +// StageFinished check if current stage finished. +func (*backfillingDispatcherExt) StageFinished(_ *proto.Task) bool { + return true +} + +// Finished check if current task finished. +func (*backfillingDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == proto.StepOne +} + // OnErrStage generate error handling stage's plan. 
func (*backfillingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, receiveErr []error) (meta []byte, err error) { // We do not need extra meta info when rolling back diff --git a/ddl/backfilling_dispatcher_test.go b/ddl/backfilling_dispatcher_test.go index 94e4ebb487..832b80b34b 100644 --- a/ddl/backfilling_dispatcher_test.go +++ b/ddl/backfilling_dispatcher_test.go @@ -38,7 +38,7 @@ func TestBackfillingDispatcher(t *testing.T) { tk := testkit.NewTestKit(t, store) tk.MustExec("use test") - // test partition table OnNextStage. + /// 1. test partition table. tk.MustExec("create table tp1(id int primary key, v int) PARTITION BY RANGE (id) (\n " + "PARTITION p0 VALUES LESS THAN (10),\n" + "PARTITION p1 VALUES LESS THAN (100),\n" + @@ -48,9 +48,10 @@ func TestBackfillingDispatcher(t *testing.T) { tbl, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("tp1")) require.NoError(t, err) tblInfo := tbl.Meta() - metas, err := dsp.OnNextStage(context.Background(), nil, gTask) + + // 1.1 OnNextSubtasksBatch + metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask) require.NoError(t, err) - require.Equal(t, proto.StepOne, gTask.Step) require.Equal(t, len(tblInfo.Partition.Definitions), len(metas)) for i, par := range tblInfo.Partition.Definitions { var subTask ddl.BackfillSubTaskMeta @@ -58,13 +59,14 @@ func TestBackfillingDispatcher(t *testing.T) { require.Equal(t, par.ID, subTask.PhysicalTableID) } - // test partition table OnNextStage after step1 finished. + // 1.2 test partition table OnNextSubtasksBatch after StepInit finished. gTask.State = proto.TaskStateRunning - metas, err = dsp.OnNextStage(context.Background(), nil, gTask) + gTask.Step++ + metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask) require.NoError(t, err) require.Equal(t, 0, len(metas)) - // test partition table OnErrStage. + // 1.3 test partition table OnErrStage. errMeta, err := dsp.OnErrStage(context.Background(), nil, gTask, []error{errors.New("mockErr")}) require.NoError(t, err) require.Nil(t, errMeta) @@ -73,10 +75,30 @@ func TestBackfillingDispatcher(t *testing.T) { require.NoError(t, err) require.Nil(t, errMeta) + /// 2. test non partition table. + // 2.1 empty table tk.MustExec("create table t1(id int primary key, v int)") gTask = createAddIndexGlobalTask(t, dom, "test", "t1", ddl.BackfillTaskType) - _, err = dsp.OnNextStage(context.Background(), nil, gTask) + metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask) require.NoError(t, err) + require.Equal(t, 0, len(metas)) + // 2.2 non empty table. 
+ tk.MustExec("create table t2(id bigint auto_random primary key)") + tk.MustExec("insert into t2 values (), (), (), (), (), ()") + tk.MustExec("insert into t2 values (), (), (), (), (), ()") + tk.MustExec("insert into t2 values (), (), (), (), (), ()") + tk.MustExec("insert into t2 values (), (), (), (), (), ()") + gTask = createAddIndexGlobalTask(t, dom, "test", "t2", ddl.BackfillTaskType) + // 2.2.1 stepInit + metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask) + require.NoError(t, err) + require.Equal(t, 1, len(metas)) + // 2.2.2 stepOne + gTask.Step++ + gTask.State = proto.TaskStateRunning + metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask) + require.NoError(t, err) + require.Equal(t, 1, len(metas)) } func createAddIndexGlobalTask(t *testing.T, dom *domain.Domain, dbName, tblName string, taskType string) *proto.Task { diff --git a/ddl/index.go b/ddl/index.go index f017572ac9..d8d3898024 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -1942,7 +1942,7 @@ func (w *worker) updateJobRowCount(taskKey string, jobID int64) { logutil.BgLogger().Warn("cannot get global task", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err)) return } - rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepOne) + rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepInit) if err != nil { logutil.BgLogger().Warn("cannot get subtask row count", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err)) return diff --git a/ddl/stage_scheduler.go b/ddl/stage_scheduler.go index e2258ce73b..64f521fbbe 100644 --- a/ddl/stage_scheduler.go +++ b/ddl/stage_scheduler.go @@ -98,9 +98,9 @@ func newBackfillDistScheduler(ctx context.Context, id string, taskID int64, task func (s *backfillDistScheduler) GetSubtaskExecutor(ctx context.Context, task *proto.Task, summary *execute.Summary) (execute.SubtaskExecutor, error) { switch task.Step { - case proto.StepOne: + case proto.StepInit: return NewBackfillSchedulerHandle(ctx, task.Meta, s.d, false, summary) - case proto.StepTwo: + case proto.StepOne: return NewBackfillSchedulerHandle(ctx, task.Meta, s.d, true, nil) default: return nil, errors.Errorf("unknown backfill step %d for task %d", task.Step, task.ID) diff --git a/disttask/framework/BUILD.bazel b/disttask/framework/BUILD.bazel index 9099913c05..b373a76c0b 100644 --- a/disttask/framework/BUILD.bazel +++ b/disttask/framework/BUILD.bazel @@ -4,6 +4,7 @@ go_test( name = "framework_test", timeout = "short", srcs = [ + "framework_dynamic_dispatch_test.go", "framework_err_handling_test.go", "framework_ha_test.go", "framework_rollback_test.go", @@ -11,7 +12,7 @@ go_test( ], flaky = True, race = "off", - shard_count = 24, + shard_count = 26, deps = [ "//disttask/framework/dispatcher", "//disttask/framework/mock", diff --git a/disttask/framework/dispatcher/dispatcher.go b/disttask/framework/dispatcher/dispatcher.go index 48d96be071..98924e904f 100644 --- a/disttask/framework/dispatcher/dispatcher.go +++ b/disttask/framework/dispatcher/dispatcher.go @@ -57,6 +57,8 @@ var ( type TaskHandle interface { // GetPreviousSubtaskMetas gets previous subtask metas. GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte, error) + // UpdateTask update the task in tidb_global_task table. + UpdateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) error storage.SessionExecutor } @@ -185,7 +187,7 @@ func (d *BaseDispatcher) scheduleTask() { // handle task in cancelling state, dispatch revert subtasks. 
func (d *BaseDispatcher) onCancelling() error { - logutil.Logger(d.logCtx).Debug("on cancelling state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step)) + logutil.Logger(d.logCtx).Info("on cancelling state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step)) errs := []error{errors.New("cancel")} return d.onErrHandlingStage(errs) } @@ -198,11 +200,10 @@ func (d *BaseDispatcher) onReverting() error { logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err)) return err } - prevStageFinished := cnt == 0 - if prevStageFinished { + if cnt == 0 { // Finish the rollback step. - logutil.Logger(d.logCtx).Info("update the task to reverted state") - return d.updateTask(proto.TaskStateReverted, nil, RetrySQLTimes) + logutil.Logger(d.logCtx).Info("all reverting tasks finished, update the task to reverted state") + return d.UpdateTask(proto.TaskStateReverted, nil, RetrySQLTimes) } // Wait all subtasks in this stage finished. d.OnTick(d.ctx, d.Task) @@ -236,9 +237,19 @@ func (d *BaseDispatcher) onRunning() error { return err } - prevStageFinished := cnt == 0 - if prevStageFinished { - logutil.Logger(d.logCtx).Info("previous stage finished, generate dist plan", zap.Int64("stage", d.Task.Step)) + if cnt == 0 { + logutil.Logger(d.logCtx).Info("previous subtasks finished, generate dist plan", zap.Int64("stage", d.Task.Step)) + // When all subtasks dispatched and processed, mark task as succeed. + if d.Finished(d.Task) { + d.Task.StateUpdateTime = time.Now().UTC() + logutil.Logger(d.logCtx).Info("all subtasks dispatched and processed, finish the task") + err := d.UpdateTask(proto.TaskStateSucceed, nil, RetrySQLTimes) + if err != nil { + logutil.Logger(d.logCtx).Warn("update task failed", zap.Error(err)) + return err + } + return nil + } return d.onNextStage() } // Check if any node are down. @@ -309,7 +320,29 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error { return nil } -func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) { +func (d *BaseDispatcher) addSubtasks(subtasks []*proto.Subtask) (err error) { + for i := 0; i < RetrySQLTimes; i++ { + err = d.taskMgr.AddSubTasks(d.Task, subtasks) + if err == nil { + break + } + if i%10 == 0 { + logutil.Logger(d.logCtx).Warn("addSubtasks failed", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step), + zap.Int("subtask cnt", len(subtasks)), + zap.Int("retry times", i), zap.Error(err)) + } + time.Sleep(RetrySQLInterval) + } + if err != nil { + logutil.Logger(d.logCtx).Warn("addSubtasks failed", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step), + zap.Int("subtask cnt", len(subtasks)), + zap.Int("retry times", RetrySQLTimes), zap.Error(err)) + } + return err +} + +// UpdateTask update the task in tidb_global_task table. +func (d *BaseDispatcher) UpdateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) { prevState := d.Task.State d.Task.State = taskState if !VerifyTaskStateTransform(prevState, taskState) { @@ -330,7 +363,7 @@ func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subta } if i%10 == 0 { logutil.Logger(d.logCtx).Warn("updateTask first failed", zap.String("from", prevState), zap.String("to", d.Task.State), - zap.Int("retry times", retryTimes), zap.Error(err)) + zap.Int("retry times", i), zap.Error(err)) } time.Sleep(RetrySQLInterval) } @@ -351,11 +384,11 @@ func (d *BaseDispatcher) onErrHandlingStage(receiveErr []error) error { } // 2. 
dispatch revert dist-plan to EligibleInstances. - return d.dispatchSubTask4Revert(d.Task, meta) + return d.dispatchSubTask4Revert(meta) } -func (d *BaseDispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error { - instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, task) +func (d *BaseDispatcher) dispatchSubTask4Revert(meta []byte) error { + instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, d.Task) if err != nil { logutil.Logger(d.logCtx).Warn("get task's all instances failed", zap.Error(err)) return err @@ -363,54 +396,74 @@ func (d *BaseDispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) e subTasks := make([]*proto.Subtask, 0, len(instanceIDs)) for _, id := range instanceIDs { - subTasks = append(subTasks, proto.NewSubtask(task.ID, task.Type, id, meta)) + subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, id, meta)) } - return d.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) + return d.UpdateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) } func (d *BaseDispatcher) onNextStage() error { - // 1. generate the needed global task meta and subTask meta (dist-plan). - metas, err := d.OnNextStage(d.ctx, d, d.Task) - if err != nil { - return d.handlePlanErr(err) - } - // 2. dispatch dist-plan to EligibleInstances. - return d.dispatchSubTask(d.Task, metas) -} + /// dynamic dispatch subtasks. + failpoint.Inject("mockDynamicDispatchErr", func() { + failpoint.Return(errors.New("mockDynamicDispatchErr")) + }) -func (d *BaseDispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error { - logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.Task.State), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas))) // 1. Adjust the global task's concurrency. - if task.Concurrency == 0 { - task.Concurrency = DefaultSubtaskConcurrency - } - if task.Concurrency > MaxSubtaskConcurrency { - task.Concurrency = MaxSubtaskConcurrency - } - - retryTimes := RetrySQLTimes - // 2. Special handling for the new tasks. - if task.State == proto.TaskStatePending { - // TODO: Consider using TS. - nowTime := time.Now().UTC() - task.StartTime = nowTime - task.StateUpdateTime = nowTime - retryTimes = nonRetrySQLTime - } - - if len(metas) == 0 { - task.StateUpdateTime = time.Now().UTC() - // Write the global task meta into the storage. - err := d.updateTask(proto.TaskStateSucceed, nil, retryTimes) - if err != nil { - logutil.Logger(d.logCtx).Warn("update task failed", zap.Error(err)) + if d.Task.State == proto.TaskStatePending { + if d.Task.Concurrency == 0 { + d.Task.Concurrency = DefaultSubtaskConcurrency + } + if d.Task.Concurrency > MaxSubtaskConcurrency { + d.Task.Concurrency = MaxSubtaskConcurrency + } + d.Task.StateUpdateTime = time.Now().UTC() + if err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes); err != nil { + return err + } + } else if d.StageFinished(d.Task) { + // 2. when previous stage finished, update to next stage. + d.Task.Step++ + logutil.Logger(d.logCtx).Info("previous stage finished, run into next stage", zap.Int64("from", d.Task.Step-1), zap.Int64("to", d.Task.Step)) + d.Task.StateUpdateTime = time.Now().UTC() + err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes) + if err != nil { return err } - return nil } - // 3. select all available TiDB nodes for task. - serverNodes, err := d.GetEligibleInstances(d.ctx, task) + for { + // 3. generate a batch of subtasks. 
+ metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task) + if err != nil { + logutil.Logger(d.logCtx).Warn("generate part of subtasks failed", zap.Error(err)) + return d.handlePlanErr(err) + } + + failpoint.Inject("mockDynamicDispatchErr1", func() { + failpoint.Return(errors.New("mockDynamicDispatchErr1")) + }) + + // 4. dispatch batch of subtasks to EligibleInstances. + err = d.dispatchSubTask(metas) + if err != nil { + return err + } + + if d.StageFinished(d.Task) { + break + } + + failpoint.Inject("mockDynamicDispatchErr2", func() { + failpoint.Return(errors.New("mockDynamicDispatchErr2")) + }) + } + return nil +} + +func (d *BaseDispatcher) dispatchSubTask(metas [][]byte) error { + logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas))) + + // select all available TiDB nodes for task. + serverNodes, err := d.GetEligibleInstances(d.ctx, d.Task) logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes))) if err != nil { @@ -434,12 +487,13 @@ func (d *BaseDispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error subTasks := make([]*proto.Subtask, 0, len(metas)) for i, meta := range metas { // we assign the subtask to the instance in a round-robin way. + // TODO: assign the subtask to the instance according to the system load of each nodes pos := i % len(serverNodes) instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port) logutil.Logger(d.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID)) - subTasks = append(subTasks, proto.NewSubtask(task.ID, task.Type, instanceID, meta)) + subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, instanceID, meta)) } - return d.updateTask(proto.TaskStateRunning, subTasks, RetrySQLTimes) + return d.addSubtasks(subTasks) } func (d *BaseDispatcher) handlePlanErr(err error) error { @@ -449,7 +503,7 @@ func (d *BaseDispatcher) handlePlanErr(err error) error { } d.Task.Error = err // state transform: pending -> failed. - return d.updateTask(proto.TaskStateFailed, nil, RetrySQLTimes) + return d.UpdateTask(proto.TaskStateFailed, nil, RetrySQLTimes) } // GenerateSchedulerNodes generate a eligible TiDB nodes. 
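Note on the dispatcher changes above: the loop in onNextStage changes the dispatcher/extension contract. OnNextSubtasksBatch can now be called several times per stage, the framework only advances task.Step once StageFinished reports true, and it marks the task succeeded once Finished reports true. The following sketch shows what a minimal extension written against that contract could look like; the type name batchingDispatcherExt and its counter are hypothetical, modeled on the test extensions added later in this patch, and the snippet is an illustration rather than code from the patch.

package example

import (
	"context"

	"github.com/pingcap/tidb/disttask/framework/dispatcher"
	"github.com/pingcap/tidb/disttask/framework/proto"
	"github.com/pingcap/tidb/domain/infosync"
)

// batchingDispatcherExt is a hypothetical extension that produces its StepInit plan in
// three small batches instead of one big plan.
type batchingDispatcherExt struct {
	batchesSent int // batches produced so far for StepInit
}

var _ dispatcher.Extension = (*batchingDispatcherExt)(nil)

func (*batchingDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {}

// OnNextSubtasksBatch returns the next batch of subtask metas; an empty result now means
// "nothing more to dispatch for this step" rather than finishing the whole task.
func (e *batchingDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) ([][]byte, error) {
	if task.Step == proto.StepInit && e.batchesSent < 3 {
		e.batchesSent++
		return [][]byte{[]byte("subtask-meta")}, nil
	}
	return nil, nil
}

// StageFinished stops the batch loop for the current step; the framework then bumps task.Step.
func (e *batchingDispatcherExt) StageFinished(_ *proto.Task) bool {
	return e.batchesSent >= 3
}

// Finished tells the framework it can mark the task succeeded once all subtasks are processed.
func (*batchingDispatcherExt) Finished(task *proto.Task) bool {
	return task.Step == proto.StepOne
}

func (*batchingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) ([]byte, error) {
	return nil, nil
}

func (*batchingDispatcherExt) GetEligibleInstances(ctx context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
	return dispatcher.GenerateSchedulerNodes(ctx)
}

func (*batchingDispatcherExt) IsRetryableErr(error) bool { return true }

Producing the plan in batches like this is what lets a stage with a very large plan start running its first subtasks before the whole plan has been generated.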
diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index 38680b4e9a..52a9aefe40 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -48,7 +48,7 @@ type testDispatcherExt struct{} func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { +func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { return nil, nil } @@ -66,12 +66,20 @@ func (*testDispatcherExt) IsRetryableErr(error) bool { return true } +func (dsp *testDispatcherExt) StageFinished(task *proto.Task) bool { + return true +} + +func (dsp *testDispatcherExt) Finished(task *proto.Task) bool { + return false +} + type numberExampleDispatcherExt struct{} func (*numberExampleDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (n *numberExampleDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { +func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) { if task.State == proto.TaskStatePending { task.Step = proto.StepInit } @@ -104,6 +112,14 @@ func (*numberExampleDispatcherExt) IsRetryableErr(error) bool { return true } +func (*numberExampleDispatcherExt) StageFinished(task *proto.Task) bool { + return true +} + +func (*numberExampleDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == proto.StepTwo +} + func MockDispatcherManager(t *testing.T, pool *pools.ResourcePool) (*dispatcher.Manager, *storage.TaskManager) { ctx := context.WithValue(context.Background(), "etcd", true) mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool) @@ -255,7 +271,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { require.NoError(t, err) taskIDs = append(taskIDs, taskID) } - // test OnNextStage. + // test OnNextSubtasksBatch. checkGetRunningTaskCnt(taskCnt) tasks := checkTaskRunningCnt() for i, taskID := range taskIDs { diff --git a/disttask/framework/dispatcher/interface.go b/disttask/framework/dispatcher/interface.go index af021d0daf..758da0c52d 100644 --- a/disttask/framework/dispatcher/interface.go +++ b/disttask/framework/dispatcher/interface.go @@ -31,23 +31,35 @@ type Extension interface { // OnTick is used to handle the ticker event, if business impl need to do some periodical work, you can // do it here, but don't do too much work here, because the ticker interval is small, and it will block // the event is generated every checkTaskRunningInterval, and only when the task NOT FINISHED and NO ERROR. - OnTick(ctx context.Context, gTask *proto.Task) - // OnNextStage is used to move the task to next stage, if returns no error and there's no new subtasks - // the task is finished. + OnTick(ctx context.Context, task *proto.Task) + + // OnNextSubtasksBatch is used to generate batch of subtasks for current stage // NOTE: don't change gTask.State inside, framework will manage it. // it's called when: // 1. task is pending and entering it's first step. - // 2. subtasks of previous step has all finished with no error. - OnNextStage(ctx context.Context, h TaskHandle, gTask *proto.Task) (subtaskMetas [][]byte, err error) + // 2. 
subtasks dispatched has all finished with no error. + OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task) (subtaskMetas [][]byte, err error) + // OnErrStage is called when: // 1. subtask is finished with error. // 2. task is cancelled after we have dispatched some subtasks. - OnErrStage(ctx context.Context, h TaskHandle, gTask *proto.Task, receiveErr []error) (subtaskMeta []byte, err error) + OnErrStage(ctx context.Context, h TaskHandle, task *proto.Task, receiveErr []error) (subtaskMeta []byte, err error) + // GetEligibleInstances is used to get the eligible instances for the task. // on certain condition we may want to use some instances to do the task, such as instances with more disk. - GetEligibleInstances(ctx context.Context, gTask *proto.Task) ([]*infosync.ServerInfo, error) + GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, error) + // IsRetryableErr is used to check whether the error occurred in dispatcher is retryable. IsRetryableErr(err error) bool + + // StageFinished is used to check if all subtasks in current stage are dispatched and processed. + // StageFinished is called before generating batch of subtasks. + StageFinished(task *proto.Task) bool + + // Finished is used to check if all subtasks for the task are dispatched and processed. + // Finished is called before generating batch of subtasks. + // Once Finished return true, mark the task as succeed. + Finished(task *proto.Task) bool } // FactoryFn is used to create a dispatcher. diff --git a/disttask/framework/framework_dynamic_dispatch_test.go b/disttask/framework/framework_dynamic_dispatch_test.go new file mode 100644 index 0000000000..29f5331e3c --- /dev/null +++ b/disttask/framework/framework_dynamic_dispatch_test.go @@ -0,0 +1,114 @@ +// Copyright 2023 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package framework_test + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/pingcap/failpoint" + "github.com/pingcap/tidb/disttask/framework/dispatcher" + "github.com/pingcap/tidb/disttask/framework/proto" + "github.com/pingcap/tidb/domain/infosync" + "github.com/pingcap/tidb/testkit" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" +) + +type testDynamicDispatcherExt struct { + cnt int +} + +var _ dispatcher.Extension = (*testDynamicDispatcherExt)(nil) + +func (*testDynamicDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {} + +func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { + // step1 + if gTask.Step == proto.StepInit && dsp.cnt < 3 { + dsp.cnt++ + return [][]byte{ + []byte(fmt.Sprintf("task%d", dsp.cnt)), + []byte(fmt.Sprintf("task%d", dsp.cnt)), + }, nil + } + + // step2 + if gTask.Step == proto.StepOne && dsp.cnt < 4 { + dsp.cnt++ + return [][]byte{ + []byte(fmt.Sprintf("task%d", dsp.cnt)), + }, nil + } + return nil, nil +} + +func (*testDynamicDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) { + return nil, nil +} + +func (dsp *testDynamicDispatcherExt) StageFinished(task *proto.Task) bool { + if task.Step == proto.StepInit && dsp.cnt >= 3 { + return true + } + if task.Step == proto.StepOne && dsp.cnt >= 4 { + return true + } + return false +} + +func (dsp *testDynamicDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == proto.StepOne && dsp.cnt >= 4 +} + +func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { + return generateSchedulerNodes4Test() +} + +func (*testDynamicDispatcherExt) IsRetryableErr(error) bool { + return true +} + +func TestFrameworkDynamicBasic(t *testing.T) { + var m sync.Map + ctrl := gomock.NewController(t) + defer ctrl.Finish() + RegisterTaskMeta(t, ctrl, &m, &testDynamicDispatcherExt{}) + distContext := testkit.NewDistExecutionContext(t, 3) + DispatchTaskAndCheckSuccess("key1", t, &m) + distContext.Close() +} + +func TestFrameworkDynamicHA(t *testing.T) { + var m sync.Map + ctrl := gomock.NewController(t) + defer ctrl.Finish() + RegisterTaskMeta(t, ctrl, &m, &testDynamicDispatcherExt{}) + distContext := testkit.NewDistExecutionContext(t, 3) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr", "5*return()")) + DispatchTaskAndCheckSuccess("key1", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr")) + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr1", "5*return()")) + DispatchTaskAndCheckSuccess("key2", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr1")) + + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr2", "5*return()")) + DispatchTaskAndCheckSuccess("key3", t, &m) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr2")) + distContext.Close() +} diff --git a/disttask/framework/framework_err_handling_test.go b/disttask/framework/framework_err_handling_test.go index ffbe25360c..6479c3e7ae 100644 --- a/disttask/framework/framework_err_handling_test.go +++ 
b/disttask/framework/framework_err_handling_test.go @@ -29,6 +29,7 @@ import ( type planErrDispatcherExt struct { callTime int + cnt int } var ( @@ -39,13 +40,13 @@ var ( func (*planErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (p *planErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { - if gTask.State == proto.TaskStatePending { +func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { + if gTask.Step == proto.StepInit { if p.callTime == 0 { p.callTime++ return nil, errors.New("retryable err") } - gTask.Step = proto.StepOne + p.cnt = 3 return [][]byte{ []byte("task1"), []byte("task2"), @@ -53,7 +54,7 @@ func (p *planErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskH }, nil } if gTask.Step == proto.StepOne { - gTask.Step = proto.StepTwo + p.cnt = 4 return [][]byte{ []byte("task4"), }, nil @@ -77,13 +78,31 @@ func (*planErrDispatcherExt) IsRetryableErr(error) bool { return true } +func (p *planErrDispatcherExt) StageFinished(task *proto.Task) bool { + if task.Step == proto.StepInit && p.cnt == 3 { + return true + } + if task.Step == proto.StepOne && p.cnt == 4 { + return true + } + return false +} + +func (p *planErrDispatcherExt) Finished(task *proto.Task) bool { + if task.Step == proto.StepOne && p.cnt == 4 { + return true + } + return false +} + type planNotRetryableErrDispatcherExt struct { + cnt int } func (*planNotRetryableErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (p *planNotRetryableErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { +func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { return nil, errors.New("not retryable err") } @@ -99,12 +118,25 @@ func (*planNotRetryableErrDispatcherExt) IsRetryableErr(error) bool { return false } +func (p *planNotRetryableErrDispatcherExt) StageFinished(task *proto.Task) bool { + if task.Step == proto.StepInit && p.cnt >= 3 { + return true + } + if task.Step == proto.StepOne && p.cnt >= 4 { + return true + } + return false +} + +func (p *planNotRetryableErrDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == proto.StepOne && p.cnt >= 4 +} + func TestPlanErr(t *testing.T) { m := sync.Map{} - ctrl := gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0}) + RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0, 0}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) distContext.Close() @@ -115,7 +147,7 @@ func TestRevertPlanErr(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0}) + RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0, 0}) distContext := testkit.NewDistExecutionContext(t, 2) DispatchTaskAndCheckSuccess("key1", t, &m) distContext.Close() @@ -123,7 +155,6 @@ func TestRevertPlanErr(t *testing.T) { func TestPlanNotRetryableErr(t *testing.T) { m := sync.Map{} - ctrl := gomock.NewController(t) defer ctrl.Finish() RegisterTaskMeta(t, ctrl, &m, &planNotRetryableErrDispatcherExt{}) diff --git a/disttask/framework/framework_ha_test.go b/disttask/framework/framework_ha_test.go index 4084cfe9b2..7cc43d93f2 100644 --- a/disttask/framework/framework_ha_test.go +++ 
b/disttask/framework/framework_ha_test.go @@ -28,16 +28,18 @@ import ( "go.uber.org/mock/gomock" ) -type haTestFlowHandle struct{} - -var _ dispatcher.Extension = (*haTestFlowHandle)(nil) - -func (*haTestFlowHandle) OnTick(_ context.Context, _ *proto.Task) { +type haTestDispatcherExt struct { + cnt int } -func (*haTestFlowHandle) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { - if gTask.State == proto.TaskStatePending { - gTask.Step = proto.StepOne +var _ dispatcher.Extension = (*haTestDispatcherExt)(nil) + +func (*haTestDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { +} + +func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { + if gTask.Step == proto.StepInit { + dsp.cnt = 10 return [][]byte{ []byte("task1"), []byte("task2"), @@ -52,7 +54,7 @@ func (*haTestFlowHandle) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, }, nil } if gTask.Step == proto.StepOne { - gTask.Step = proto.StepTwo + dsp.cnt = 15 return [][]byte{ []byte("task11"), []byte("task12"), @@ -64,23 +66,37 @@ func (*haTestFlowHandle) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, return nil, nil } -func (*haTestFlowHandle) OnErrStage(ctx context.Context, h dispatcher.TaskHandle, gTask *proto.Task, receiveErr []error) (subtaskMeta []byte, err error) { +func (*haTestDispatcherExt) OnErrStage(ctx context.Context, h dispatcher.TaskHandle, gTask *proto.Task, receiveErr []error) (subtaskMeta []byte, err error) { return nil, nil } -func (*haTestFlowHandle) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { +func (*haTestDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) { return generateSchedulerNodes4Test() } -func (*haTestFlowHandle) IsRetryableErr(error) bool { +func (*haTestDispatcherExt) IsRetryableErr(error) bool { return true } +func (dsp *haTestDispatcherExt) StageFinished(task *proto.Task) bool { + if task.Step == proto.StepInit && dsp.cnt >= 10 { + return true + } + if task.Step == proto.StepOne && dsp.cnt >= 15 { + return true + } + return false +} + +func (dsp *haTestDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == proto.StepOne && dsp.cnt >= 15 +} + func TestHABasic(t *testing.T) { var m sync.Map ctrl := gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{}) + RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 4) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "4*return()")) @@ -94,10 +110,9 @@ func TestHABasic(t *testing.T) { func TestHAManyNodes(t *testing.T) { var m sync.Map - ctrl := gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{}) + RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 30) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()")) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "30*return()")) @@ -111,10 +126,9 @@ func TestHAManyNodes(t *testing.T) { func TestHAFailInDifferentStage(t *testing.T) { var m sync.Map - ctrl := 
gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{}) + RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 6) // stage1 : server num from 6 to 3. // stage2 : server num from 3 to 2. @@ -136,7 +150,7 @@ func TestHAFailInDifferentStageManyNodes(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{}) + RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 30) // stage1 : server num from 30 to 27. // stage2 : server num from 27 to 26. @@ -158,7 +172,7 @@ func TestHAReplacedButRunning(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{}) + RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 4) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume", "10*return(true)")) DispatchTaskAndCheckSuccess("😊", t, &m) @@ -171,7 +185,7 @@ func TestHAReplacedButRunningManyNodes(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{}) + RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{}) distContext := testkit.NewDistExecutionContext(t, 30) require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume", "30*return(true)")) DispatchTaskAndCheckSuccess("😊", t, &m) diff --git a/disttask/framework/framework_rollback_test.go b/disttask/framework/framework_rollback_test.go index 9b962ab26a..2fd29c4c18 100644 --- a/disttask/framework/framework_rollback_test.go +++ b/disttask/framework/framework_rollback_test.go @@ -30,7 +30,9 @@ import ( "go.uber.org/mock/gomock" ) -type rollbackDispatcherExt struct{} +type rollbackDispatcherExt struct { + cnt int +} var _ dispatcher.Extension = (*rollbackDispatcherExt)(nil) var rollbackCnt atomic.Int32 @@ -38,9 +40,9 @@ var rollbackCnt atomic.Int32 func (*rollbackDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (*rollbackDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { - if gTask.State == proto.TaskStatePending { - gTask.Step = proto.StepOne +func (dsp *rollbackDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { + if gTask.Step == proto.StepInit { + dsp.cnt = 3 return [][]byte{ []byte("task1"), []byte("task2"), @@ -62,6 +64,14 @@ func (*rollbackDispatcherExt) IsRetryableErr(error) bool { return true } +func (dsp *rollbackDispatcherExt) StageFinished(task *proto.Task) bool { + return task.Step == proto.StepInit && dsp.cnt >= 3 +} + +func (dsp *rollbackDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == proto.StepInit && dsp.cnt >= 3 +} + type testRollbackMiniTask struct{} func (testRollbackMiniTask) IsMinimalTask() {} @@ -125,7 +135,7 @@ func TestFrameworkRollback(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelTaskAfterRefreshTask")) }() - DispatchTaskAndCheckState("key2", t, &m, proto.TaskStateReverted) + DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateReverted) require.Equal(t, int32(2), rollbackCnt.Load()) rollbackCnt.Store(0) distContext.Close() diff --git a/disttask/framework/framework_test.go 
b/disttask/framework/framework_test.go index 6aae0f6e41..916ad4deb8 100644 --- a/disttask/framework/framework_test.go +++ b/disttask/framework/framework_test.go @@ -35,16 +35,18 @@ import ( "go.uber.org/mock/gomock" ) -type testDispatcherExt struct{} +type testDispatcherExt struct { + cnt int +} var _ dispatcher.Extension = (*testDispatcherExt)(nil) func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) { } -func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { - if gTask.State == proto.TaskStatePending { - gTask.Step = proto.StepOne +func (dsp *testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) { + if gTask.Step == proto.StepInit { + dsp.cnt = 3 return [][]byte{ []byte("task1"), []byte("task2"), @@ -52,7 +54,7 @@ func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle }, nil } if gTask.Step == proto.StepOne { - gTask.Step = proto.StepTwo + dsp.cnt = 4 return [][]byte{ []byte("task4"), }, nil @@ -64,6 +66,20 @@ func (*testDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, return nil, nil } +func (dsp *testDispatcherExt) StageFinished(task *proto.Task) bool { + if task.Step == proto.StepInit && dsp.cnt >= 3 { + return true + } + if task.Step == proto.StepOne && dsp.cnt >= 4 { + return true + } + return false +} + +func (dsp *testDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == proto.StepOne && dsp.cnt >= 4 +} + func generateSchedulerNodes4Test() ([]*infosync.ServerInfo, error) { serverInfos := infosync.MockGlobalServerInfoManagerEntry.GetAllServerInfo() if len(serverInfos) == 0 { @@ -174,9 +190,9 @@ func RegisterTaskMeta(t *testing.T, ctrl *gomock.Controller, m *sync.Map, dispat mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) { switch step { - case proto.StepOne: + case proto.StepInit: return &testSubtaskExecutor{m: m}, nil - case proto.StepTwo: + case proto.StepOne: return &testSubtaskExecutor1{m: m}, nil } panic("invalid step") @@ -212,9 +228,9 @@ func RegisterTaskMetaForExample2(t *testing.T, ctrl *gomock.Controller, m *sync. mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) { switch step { - case proto.StepOne: + case proto.StepInit: return &testSubtaskExecutor2{m: m}, nil - case proto.StepTwo: + case proto.StepOne: return &testSubtaskExecutor3{m: m}, nil } panic("invalid step") @@ -223,6 +239,10 @@ func RegisterTaskMetaForExample2(t *testing.T, ctrl *gomock.Controller, m *sync. } func RegisterTaskMetaForExample2Inner(t *testing.T, mockExtension scheduler.Extension, dispatcherHandle dispatcher.Extension) { + t.Cleanup(func() { + dispatcher.ClearDispatcherFactory() + scheduler.ClearSchedulers() + }) dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample2, func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task) @@ -244,9 +264,9 @@ func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync. 
mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) { switch step { - case proto.StepOne: + case proto.StepInit: return &testSubtaskExecutor4{m: m}, nil - case proto.StepTwo: + case proto.StepOne: return &testSubtaskExecutor5{m: m}, nil } panic("invalid step") @@ -255,6 +275,10 @@ func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync. } func RegisterTaskMetaForExample3Inner(t *testing.T, mockExtension scheduler.Extension, dispatcherHandle dispatcher.Extension) { + t.Cleanup(func() { + dispatcher.ClearDispatcherFactory() + scheduler.ClearSchedulers() + }) dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample3, func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher { baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task) @@ -500,6 +524,7 @@ func TestFrameworkSubTaskFailed(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockExecutorRunErr")) }() DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateReverted) + distContext.Close() } @@ -563,7 +588,6 @@ func TestSchedulerDownBasic(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) - distContext.Close() } @@ -582,7 +606,6 @@ func TestSchedulerDownManyNodes(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler")) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown")) require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager")) - distContext.Close() } diff --git a/disttask/framework/proto/task.go b/disttask/framework/proto/task.go index 5996e61a90..c1ab89f1a4 100644 --- a/disttask/framework/proto/task.go +++ b/disttask/framework/proto/task.go @@ -48,7 +48,7 @@ const ( // TaskStep is the step of task. const ( - StepInit int64 = -1 + StepInit int64 = 0 StepOne int64 = 1 StepTwo int64 = 2 ) @@ -56,7 +56,7 @@ const ( // TaskIDLabelName is the label name of task id. const TaskIDLabelName = "task_id" -// Task represents the task of distribute framework. +// Task represents the task of distributed framework. 
type Task struct { ID int64 Key string diff --git a/disttask/framework/scheduler/manager.go b/disttask/framework/scheduler/manager.go index 8546bac07f..65298401c4 100644 --- a/disttask/framework/scheduler/manager.go +++ b/disttask/framework/scheduler/manager.go @@ -301,8 +301,8 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str v, ok := testContexts.Load(m.id) if ok { <-v.(*TestContext).TestSyncSubtaskRun - m.Stop() _ = infosync.MockGlobalServerInfoManagerEntry.DeleteByID(m.id) + m.Stop() } }() }) diff --git a/disttask/framework/scheduler/scheduler.go b/disttask/framework/scheduler/scheduler.go index 6630a07242..f290bfdefc 100644 --- a/disttask/framework/scheduler/scheduler.go +++ b/disttask/framework/scheduler/scheduler.go @@ -84,7 +84,6 @@ func (s *BaseScheduler) startCancelCheck(ctx context.Context, wg *sync.WaitGroup return case <-ticker.C: canceled, err := s.taskTable.IsSchedulerCanceled(s.taskID, s.id) - logutil.Logger(s.logCtx).Info("scheduler before canceled") if err != nil { continue } @@ -170,15 +169,25 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error { subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, task.Step, proto.TaskStatePending) if err != nil { - s.onError(err) - break + logutil.Logger(s.logCtx).Warn("GetSubtaskInStates meets error", zap.Error(err)) + continue } if subtask == nil { - break + newTask, err := s.taskTable.GetGlobalTaskByID(task.ID) + if err != nil { + logutil.Logger(s.logCtx).Warn("GetGlobalTaskByID meets error", zap.Error(err)) + continue + } + // When the task move to next step or task state changes, the scheduler should exit. + if newTask.Step != task.Step || newTask.State != task.State { + break + } + continue } s.startSubtask(subtask.ID) if err := s.getError(); err != nil { - break + logutil.Logger(s.logCtx).Warn("startSubtask meets error", zap.Error(err)) + continue } failpoint.Inject("mockCleanScheduler", func() { v, ok := testContexts.Load(s.id) @@ -216,11 +225,13 @@ func (s *BaseScheduler) runSubtask(ctx context.Context, scheduler execute.Subtas zap.Int64("subtask_step", subtask.Step)) failpoint.Inject("mockTiDBDown", func(val failpoint.Value) { + logutil.Logger(s.logCtx).Info("trigger mockTiDBDown") if s.id == val.(string) || s.id == ":4001" || s.id == ":4002" { v, ok := testContexts.Load(s.id) if ok { v.(*TestContext).TestSyncSubtaskRun <- struct{}{} v.(*TestContext).mockDown.Store(true) + logutil.Logger(s.logCtx).Info("mockTiDBDown") time.Sleep(2 * time.Second) failpoint.Return() } diff --git a/disttask/framework/scheduler/scheduler_test.go b/disttask/framework/scheduler/scheduler_test.go index 8eea24ac64..2ab1a31798 100644 --- a/disttask/framework/scheduler/scheduler_test.go +++ b/disttask/framework/scheduler/scheduler_test.go @@ -108,28 +108,9 @@ func TestSchedulerRun(t *testing.T) { err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp}) require.EqualError(t, err, poolErr.Error()) - // 4. 
get subtask failed - getSubtaskErr := errors.New("get subtask error") - cleanupErr := errors.New("clean up error") var taskID int64 = 1 - mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn) - mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(nil, getSubtaskErr) - mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(cleanupErr) - err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) - require.EqualError(t, err, getSubtaskErr.Error()) - // 5. update subtask state failed - updateSubtaskErr := errors.New("update subtask error") - mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) - mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn) - mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Step: proto.StepOne}, nil) - mockSubtaskTable.EXPECT().StartSubtask(taskID).Return(updateSubtaskErr) - mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) - err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID}) - require.EqualError(t, err, updateSubtaskErr.Error()) - - // 6. no subtask executor constructor + // 4. no subtask executor constructor subtaskExecutorRegisterErr := errors.Errorf("constructor of subtask executor for key not found") mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, subtaskExecutorRegisterErr) var concurrency uint64 = 10 @@ -145,7 +126,7 @@ func TestSchedulerRun(t *testing.T) { mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockMiniTaskExecutor, nil).AnyTimes() - // 7. run subtask failed + // 5. run subtask failed runSubtaskErr := errors.New("run subtask error") mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn) @@ -154,11 +135,12 @@ func TestSchedulerRun(t *testing.T) { mockSubtaskExecutor.EXPECT().SplitSubtask(gomock.Any(), gomock.Any()).Return([]proto.MinimalTask{MockMinimalTask{}}, nil) mockMiniTaskExecutor.EXPECT().Run(gomock.Any()).Return(runSubtaskErr) mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(taskID, proto.TaskStateFailed, gomock.Any()).Return(nil) + mockSubtaskTable.EXPECT().GetGlobalTaskByID(gomock.Any()).Return(&proto.Task{ID: taskID, Step: proto.StepTwo}, nil) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency}) require.EqualError(t, err, runSubtaskErr.Error()) - // 8. run subtask success + // 6. 
run subtask success mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn) mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Type: tp, Step: proto.StepOne}, nil) @@ -168,11 +150,12 @@ func TestSchedulerRun(t *testing.T) { mockSubtaskExecutor.EXPECT().OnFinished(gomock.Any(), gomock.Any()).Return([]byte(""), nil) mockSubtaskTable.EXPECT().FinishSubtask(int64(1), gomock.Any()).Return(nil) mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(nil, nil) + mockSubtaskTable.EXPECT().GetGlobalTaskByID(gomock.Any()).Return(&proto.Task{ID: taskID, Step: proto.StepTwo}, nil) mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency}) require.NoError(t, err) - // 9. run subtask one by one + // 7. run subtask one by one mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn) mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Type: tp, Step: proto.StepOne}, nil) @@ -193,8 +176,7 @@ func TestSchedulerRun(t *testing.T) { mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil) err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency}) require.NoError(t, err) - - // 10. cancel + // 8. cancel mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil) mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn) mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Type: tp, Step: proto.StepOne}, nil) @@ -215,8 +197,7 @@ func TestSchedulerRun(t *testing.T) { time.Sleep(time.Second) runCancel() wg.Wait() - - // 11. run subtask one by one, on error, we should wait all minimal task finished before call Cleanup + // 9. run subtask one by one, on error, we should wait all minimal task finished before call Cleanup syncCh := make(chan struct{}) lastMinimalTaskFinishTime, cleanupTime := time.Time{}, time.Time{} require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/waitUntilError", `return(true)`)) diff --git a/disttask/framework/storage/task_table.go b/disttask/framework/storage/task_table.go index b2d8b58e6e..61cbb5a38c 100644 --- a/disttask/framework/storage/task_table.go +++ b/disttask/framework/storage/task_table.go @@ -350,7 +350,7 @@ func (stm *TaskManager) UpdateErrorToSubtask(tidbID string, taskID int64, err er } // PrintSubtaskInfo log the subtask info by taskKey. -func (stm *TaskManager) PrintSubtaskInfo(taskKey int) { +func (stm *TaskManager) PrintSubtaskInfo(taskKey int64) { rs, _ := stm.executeSQLWithNewSession(stm.ctx, "select * from mysql.tidb_background_subtask where task_key = %?", taskKey) @@ -609,6 +609,27 @@ func (stm *TaskManager) UpdateFailedSchedulerIDs(taskID int64, replaceNodes map[ return err } +// AddSubTasks add new batch of subtasks. 
+func (stm *TaskManager) AddSubTasks(task *proto.Task, subtasks []*proto.Subtask) error { + err := stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error { + for _, subtask := range subtasks { + subtaskState := proto.TaskStatePending + if task.State == proto.TaskStateReverting { + subtaskState = proto.TaskStateRevertPending + } + _, err := ExecSQL(stm.ctx, se, `insert into mysql.tidb_background_subtask + (step, task_key, exec_id, meta, state, type, checkpoint, summary) + values (%?, %?, %?, %?, %?, %?, %?, %?)`, + task.Step, task.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}, "{}") + if err != nil { + return err + } + } + return nil + }) + return err +} + // UpdateGlobalTaskAndAddSubTasks update the global task and add new subtasks func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, prevState string) (bool, error) { retryable := true @@ -618,7 +639,6 @@ func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtas if err != nil { return err } - if se.GetSessionVars().StmtCtx.AffectedRows() == 0 { retryable = false return errors.New("invalid task state transform, state already changed") diff --git a/disttask/importinto/dispatcher.go b/disttask/importinto/dispatcher.go index af78eec78e..0da29143cd 100644 --- a/disttask/importinto/dispatcher.go +++ b/disttask/importinto/dispatcher.go @@ -191,8 +191,8 @@ func (dsp *ImportDispatcherExt) unregisterTask(ctx context.Context, task *proto. } } -// OnNextStage implements dispatcher.Extension interface. -func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) ( +// OnNextSubtasksBatch generate batch of next stage's plan. +func (dsp *ImportDispatcherExt) OnNextSubtasksBatch(ctx context.Context, taskHandle dispatcher.TaskHandle, gTask *proto.Task) ( resSubtaskMeta [][]byte, err error) { logger := logutil.BgLogger().With( zap.String("type", gTask.Type), @@ -204,18 +204,18 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch if err != nil { return nil, err } - logger.Info("on next stage") + logger.Info("on next subtasks batch") defer func() { // currently, framework will take the task as finished when err is not nil or resSubtaskMeta is empty. taskFinished := err == nil && len(resSubtaskMeta) == 0 if taskFinished { // todo: we're not running in a transaction with task update - if err2 := dsp.finishJob(ctx, logger, handle, gTask, taskMeta); err2 != nil { + if err2 := dsp.finishJob(ctx, logger, taskHandle, gTask, taskMeta); err2 != nil { err = err2 } } else if err != nil && !dsp.IsRetryableErr(err) { - if err2 := dsp.failJob(ctx, handle, gTask, taskMeta, logger, err.Error()); err2 != nil { + if err2 := dsp.failJob(ctx, taskHandle, gTask, taskMeta, logger, err.Error()); err2 != nil { // todo: we're not running in a transaction with task update, there might be case // failJob return error, but task update succeed. 
logger.Error("call failJob failed", zap.Error(err2)) @@ -223,22 +223,18 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch } }() - var nextStep int64 switch gTask.Step { - case proto.StepInit: + case StepImport: if metrics, ok := metric.GetCommonMetric(ctx); ok { metrics.BytesCounter.WithLabelValues(metric.StateTotalRestore).Add(float64(taskMeta.Plan.TotalFileSize)) } - - if err := preProcess(ctx, handle, gTask, taskMeta, logger); err != nil { + if err := preProcess(ctx, taskHandle, gTask, taskMeta, logger); err != nil { return nil, err } - if err = startJob(ctx, logger, handle, taskMeta); err != nil { + if err = startJob(ctx, logger, taskHandle, taskMeta); err != nil { return nil, err } - logger.Info("move to import step") - nextStep = StepImport - case StepImport: + case StepPostProcess: dsp.switchTiKV2NormalMode(ctx, gTask, logger) failpoint.Inject("clearLastSwitchTime", func() { dsp.lastSwitchTime.Store(time.Time{}) @@ -249,18 +245,20 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch failpoint.Inject("failWhenDispatchPostProcessSubtask", func() { failpoint.Return(nil, errors.New("injected error after StepImport")) }) - if err := updateResult(handle, gTask, taskMeta); err != nil { + if err := updateResult(taskHandle, gTask, taskMeta); err != nil { + return nil, err + } + if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil { return nil, err } logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result)) - nextStep = StepPostProcess - case StepPostProcess: + case StepPostProcess + 1: return nil, nil default: return nil, errors.Errorf("unknown step %d", gTask.Step) } - previousSubtaskMetas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step) + previousSubtaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step-1) if err != nil { return nil, err } @@ -273,12 +271,11 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch if err != nil { return nil, err } - metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, nextStep) + metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, gTask.Step) if err != nil { return nil, err } - gTask.Step = nextStep - logger.Info("generate subtasks", zap.Int64("step", nextStep), zap.Int("subtask-count", len(metaBytes))) + logger.Info("generate subtasks", zap.Int("subtask-count", len(metaBytes))) return metaBytes, nil } @@ -340,6 +337,16 @@ func (*ImportDispatcherExt) IsRetryableErr(error) bool { return false } +// StageFinished check if current stage finished. +func (*ImportDispatcherExt) StageFinished(_ *proto.Task) bool { + return true +} + +// Finished check if current task finished. +func (*ImportDispatcherExt) Finished(task *proto.Task) bool { + return task.Step == StepPostProcess+1 +} + func (dsp *ImportDispatcherExt) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) { dsp.updateCurrentTask(task) if dsp.disableTiKVImportMode.Load() { @@ -475,6 +482,7 @@ func updateMeta(gTask *proto.Task, taskMeta *TaskMeta) error { return err } gTask.Meta = bs + return nil } @@ -492,7 +500,7 @@ func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[i // we will update taskMeta in place and make gTask.Meta point to the new taskMeta. 
func updateResult(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error { - metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step) + metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step-1) if err != nil { return err } @@ -564,6 +572,9 @@ func (dsp *ImportDispatcherExt) finishJob(ctx context.Context, logger *zap.Logge taskHandle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error { dsp.unregisterTask(ctx, gTask) redactSensitiveInfo(gTask, taskMeta) + if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil { + return err + } summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt} // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval) @@ -582,6 +593,9 @@ func (dsp *ImportDispatcherExt) failJob(ctx context.Context, taskHandle dispatch dsp.switchTiKV2NormalMode(ctx, gTask, logger) dsp.unregisterTask(ctx, gTask) redactSensitiveInfo(gTask, taskMeta) + if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil { + return err + } // retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval) return handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger, @@ -632,8 +646,6 @@ func rollback(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Ta func stepStr(step int64) string { switch step { - case proto.StepInit: - return "init" case StepImport: return "import" case StepPostProcess: diff --git a/disttask/importinto/dispatcher_testkit_test.go b/disttask/importinto/dispatcher_testkit_test.go index e8d90eff29..83d43363e3 100644 --- a/disttask/importinto/dispatcher_testkit_test.go +++ b/disttask/importinto/dispatcher_testkit_test.go @@ -89,7 +89,7 @@ func TestDispatcherExt(t *testing.T) { // to import stage, job should be running d := dsp.MockDispatcher(task) ext := importinto.ImportDispatcherExt{} - subtaskMetas, err := ext.OnNextStage(ctx, d, task) + subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task) require.NoError(t, err) require.Len(t, subtaskMetas, 1) require.Equal(t, importinto.StepImport, task.Step) @@ -109,7 +109,8 @@ func TestDispatcherExt(t *testing.T) { require.NoError(t, manager.FinishSubtask(s.ID, []byte("{}"))) } // to post-process stage, job should be running and in validating step - subtaskMetas, err = ext.OnNextStage(ctx, d, task) + task.Step++ + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task) require.NoError(t, err) require.Len(t, subtaskMetas, 1) require.Equal(t, importinto.StepPostProcess, task.Step) @@ -118,7 +119,8 @@ func TestDispatcherExt(t *testing.T) { require.Equal(t, "running", gotJobInfo.Status) require.Equal(t, "validating", gotJobInfo.Step) // on next stage, job should be finished - subtaskMetas, err = ext.OnNextStage(ctx, d, task) + task.Step++ + subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task) require.NoError(t, err) require.Len(t, subtaskMetas, 0) gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true) @@ -133,7 +135,8 @@ func TestDispatcherExt(t *testing.T) { bs, err = logicalPlan.ToTaskMeta() require.NoError(t, err) task.Meta = bs - task.Step = proto.StepInit + // Set step to StepPostProcess to skip the rollback sql. 
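As the testkit test above shows, with dynamic dispatch the test now advances task.Step itself between OnNextSubtasksBatch calls, since in production that bookkeeping belongs to the dispatcher. A condensed, hypothetical driver loop capturing that pattern (assertions and storage writes omitted):

package sketch

import "context"

type task struct {
	ID   int64
	Step int64
}

// batcher is the slice of the extension API the test drives directly.
type batcher interface {
	OnNextSubtasksBatch(ctx context.Context, t *task) ([][]byte, error)
}

// driveSteps imitates what the testkit test does by hand: call the extension
// once per step, bump t.Step itself, and stop once a step yields no subtasks,
// at which point the task counts as finished.
func driveSteps(ctx context.Context, ext batcher, t *task) error {
	for {
		metas, err := ext.OnNextSubtasksBatch(ctx, t)
		if err != nil {
			return err
		}
		if len(metas) == 0 {
			return nil
		}
		t.Step++
	}
}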
+ task.Step = importinto.StepPostProcess require.NoError(t, importer.StartJob(ctx, conn, jobID)) _, err = ext.OnErrStage(ctx, d, task, []error{errors.New("test")}) require.NoError(t, err) diff --git a/disttask/importinto/proto.go b/disttask/importinto/proto.go index eb4dbe742b..970bd981b3 100644 --- a/disttask/importinto/proto.go +++ b/disttask/importinto/proto.go @@ -33,10 +33,10 @@ import ( // steps are processed in the following order: StepInit -> StepImport -> StepPostProcess const ( // StepImport we sort source data and ingest it into TiKV in this step. - StepImport int64 = 1 + StepImport int64 = 0 // StepPostProcess we verify checksum and add index in this step. // TODO: Might split into StepValidate and StepAddIndex later. - StepPostProcess int64 = 2 + StepPostProcess int64 = 1 ) // TaskMeta is the task of IMPORT INTO.
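To make the renumbering above concrete: StepImport drops from 1 to 0 and becomes the task's first step, StepPostProcess drops from 2 to 1, and a step equal to StepPostProcess+1 is what ImportDispatcherExt.Finished treats as done. A tiny standalone walk-through of that lifecycle; stepName is a local helper mirroring stepStr, not the real function:

package main

import "fmt"

// Step values after the renumbering: StepImport is now the first step of the
// task, and StepPostProcess+1 marks completion.
const (
	StepImport      int64 = 0
	StepPostProcess int64 = 1
)

func stepName(step int64) string {
	switch step {
	case StepImport:
		return "import"
	case StepPostProcess:
		return "post-process"
	default:
		return fmt.Sprintf("step-%d", step)
	}
}

func main() {
	// Walk the lifecycle the way the dispatcher advances it.
	for step := StepImport; step <= StepPostProcess+1; step++ {
		fmt.Printf("step=%d (%s) finished=%v\n", step, stepName(step), step == StepPostProcess+1)
	}
}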