diff --git a/disttask/framework/dispatcher/dispatcher_test.go b/disttask/framework/dispatcher/dispatcher_test.go index bda59da0ed..99923fe1c4 100644 --- a/disttask/framework/dispatcher/dispatcher_test.go +++ b/disttask/framework/dispatcher/dispatcher_test.go @@ -178,16 +178,16 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { // 3s cnt := 60 - checkGetRunningTaskCnt := func() { + checkGetRunningTaskCnt := func(expected int) { var retCnt int for i := 0; i < cnt; i++ { retCnt = dsp.GetRunningTaskCnt() - if retCnt == taskCnt { + if retCnt == expected { break } time.Sleep(time.Millisecond * 50) } - require.Equal(t, retCnt, taskCnt) + require.Equal(t, retCnt, expected) } checkTaskRunningCnt := func() []*proto.Task { @@ -215,7 +215,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { taskIDs = append(taskIDs, taskID) } // test normal flow - checkGetRunningTaskCnt() + checkGetRunningTaskCnt(taskCnt) tasks := checkTaskRunningCnt() for i, taskID := range taskIDs { require.Equal(t, int64(i+1), tasks[i].ID) @@ -227,7 +227,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { if taskCnt == 1 { taskID, err := mgr.AddNewGlobalTask(fmt.Sprintf("%d", taskCnt), taskTypeExample, 0, nil) require.NoError(t, err) - checkGetRunningTaskCnt() + checkGetRunningTaskCnt(taskCnt) // Clean the task. deleteTasks(t, store, taskID) dsp.DelRunningTask(taskID) @@ -254,7 +254,8 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { } checkGetTaskState(proto.TaskStateSucceed) require.Len(t, tasks, taskCnt) - require.Equal(t, 0, dsp.GetRunningTaskCnt()) + + checkGetRunningTaskCnt(0) return } @@ -286,7 +287,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) { } checkGetTaskState(proto.TaskStateReverted) require.Len(t, tasks, taskCnt) - require.Equal(t, 0, dsp.GetRunningTaskCnt()) + checkGetRunningTaskCnt(0) } func TestSimpleNormalFlow(t *testing.T) {