diff --git a/pkg/disttask/framework/BUILD.bazel b/pkg/disttask/framework/BUILD.bazel index 067cac82e1..a73a7a3a7d 100644 --- a/pkg/disttask/framework/BUILD.bazel +++ b/pkg/disttask/framework/BUILD.bazel @@ -13,7 +13,7 @@ go_test( ], flaky = True, race = "off", - shard_count = 31, + shard_count = 32, deps = [ "//pkg/disttask/framework/dispatcher", "//pkg/disttask/framework/handle", diff --git a/pkg/disttask/framework/dispatcher/dispatcher.go b/pkg/disttask/framework/dispatcher/dispatcher.go index ebb7bfa86b..4afbd19675 100644 --- a/pkg/disttask/framework/dispatcher/dispatcher.go +++ b/pkg/disttask/framework/dispatcher/dispatcher.go @@ -480,16 +480,20 @@ func (d *BaseDispatcher) onErrHandlingStage(receiveErrs []error) error { } func (d *BaseDispatcher) dispatchSubTask4Revert(meta []byte) error { - instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, d.Task) - if err != nil { - logutil.Logger(d.logCtx).Warn("get task's all instances failed", zap.Error(err)) - return err - } + var subTasks []*proto.Subtask + // when step of task is `StepInit`, no need to do revert + if d.Task.Step != proto.StepInit { + instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, d.Task) + if err != nil { + logutil.Logger(d.logCtx).Warn("get task's all instances failed", zap.Error(err)) + return err + } - subTasks := make([]*proto.Subtask, 0, len(instanceIDs)) - for _, id := range instanceIDs { - // reverting subtasks belong to the same step as current active step. - subTasks = append(subTasks, proto.NewSubtask(d.Task.Step, d.Task.ID, d.Task.Type, id, meta)) + subTasks = make([]*proto.Subtask, 0, len(instanceIDs)) + for _, id := range instanceIDs { + // reverting subtasks belong to the same step as current active step. + subTasks = append(subTasks, proto.NewSubtask(d.Task.Step, d.Task.ID, d.Task.Type, id, meta)) + } } return d.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes) } @@ -615,6 +619,10 @@ func (d *BaseDispatcher) dispatchSubTask( logutil.Logger(d.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID)) subTasks = append(subTasks, proto.NewSubtask(subtaskStep, d.Task.ID, d.Task.Type, instanceID, meta)) } + failpoint.Inject("cancelBeforeUpdateTask", func() { + _ = d.updateTask(proto.TaskStateCancelling, subTasks, RetrySQLTimes) + }) + return d.updateTask(d.Task.State, subTasks, RetrySQLTimes) } diff --git a/pkg/disttask/framework/framework_test.go b/pkg/disttask/framework/framework_test.go index e9c5f59660..cc3a942448 100644 --- a/pkg/disttask/framework/framework_test.go +++ b/pkg/disttask/framework/framework_test.go @@ -728,3 +728,18 @@ func TestFrameworkCleanUpRoutine(t *testing.T) { require.NotEmpty(t, tasks) distContext.Close() } + +func TestTaskCancelledBeforeUpdateTask(t *testing.T) { + var m sync.Map + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ctx := context.Background() + ctx = util.WithInternalSourceType(ctx, "dispatcher") + + RegisterTaskMeta(t, ctrl, &m, &testDispatcherExt{}) + distContext := testkit.NewDistExecutionContext(t, 1) + require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/cancelBeforeUpdateTask", "1*return(true)")) + DispatchTaskAndCheckState(ctx, "key1", t, &m, proto.TaskStateReverted) + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/disttask/framework/dispatcher/cancelBeforeUpdateTask")) + distContext.Close() +}