// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package integrationtests

import (
	"context"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pingcap/tidb/pkg/disttask/framework/handle"
	mockexecute "github.com/pingcap/tidb/pkg/disttask/framework/mock/execute"
	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
	"github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/execute"
	"github.com/pingcap/tidb/pkg/disttask/framework/testutil"
	"github.com/pingcap/tidb/pkg/sessionctx"
	"github.com/pingcap/tidb/pkg/testkit/testfailpoint"
	"github.com/pingcap/tidb/pkg/util/sqlexec"
	"github.com/stretchr/testify/require"
	"go.uber.org/mock/gomock"
)

// collectedRuntimeInfo records the task meta, task concurrency, and per-subtask
// runtime info observed by the mock step executor during a test run.
type collectedRuntimeInfo struct {
	currTaskMeta        atomic.Pointer[[]byte]
	currTaskConcurrency atomic.Int64
	activeSubtaskCount  atomic.Int64
	subtaskInfos        []subtaskRuntimeInfo
}

// subtaskRuntimeInfo records the step and concurrency a finished subtask ran with.
type subtaskRuntimeInfo struct {
	Step        proto.Step
	Concurrency int
}

// prepareModifyTaskTest registers an example task type backed by a mock scheduler
// and step executor. It returns the test context, the collected runtime info, a
// channel used to finish subtasks, a channel used to block a running subtask while
// a modification is applied, and the flag that enables that blocking.
func prepareModifyTaskTest(t *testing.T, nodeCount int) (*testutil.TestDXFContext, *collectedRuntimeInfo, chan struct{}, chan struct{}, *atomic.Bool) {
	c := testutil.NewTestDXFContext(t, nodeCount, 16, true)
	stepInfos := []testutil.StepInfo{
		{Step: proto.StepOne, SubtaskCnt: 2},
		{Step: proto.StepTwo, SubtaskCnt: 3},
	}
	schedulerExt := testutil.GetMockSchedulerExt(c.MockCtrl, testutil.SchedulerInfo{
		AllErrorRetryable: true,
		StepInfos:         stepInfos,
	})
	subtaskCh := make(chan struct{})
	runtimeInfo := &collectedRuntimeInfo{}
	var (
		testModifyWhenSubtaskRun atomic.Bool
		modifyWaitCh             = make(chan struct{})
	)
	var mu sync.Mutex
	runSubtaskFn := func(ctx context.Context, subtask *proto.Subtask) error {
		runtimeInfo.activeSubtaskCount.Add(1)
		defer runtimeInfo.activeSubtaskCount.Add(-1)
		if testModifyWhenSubtaskRun.Load() {
			// block here so the test can modify and verify the task while
			// this subtask is still running.
			<-modifyWaitCh
			<-modifyWaitCh
		}
		select {
		case <-subtaskCh:
			mu.Lock()
			runtimeInfo.subtaskInfos = append(runtimeInfo.subtaskInfos, subtaskRuntimeInfo{
				Step:        subtask.Step,
				Concurrency: subtask.Concurrency,
			})
			mu.Unlock()
		case <-ctx.Done():
			return ctx.Err()
		}
		return nil
	}
	executorExt := testutil.GetCommonTaskExecutorExt(c.MockCtrl, func(task *proto.Task) (execute.StepExecutor, error) {
		runtimeInfo.currTaskMeta.Store(&task.Meta)
		runtimeInfo.currTaskConcurrency.Store(int64(task.Concurrency))
		executor := mockexecute.NewMockStepExecutor(c.MockCtrl)
		executor.EXPECT().Init(gomock.Any()).Return(nil).AnyTimes()
		executor.EXPECT().RunSubtask(gomock.Any(), gomock.Any()).DoAndReturn(runSubtaskFn).AnyTimes()
		executor.EXPECT().GetStep().Return(task.Step).AnyTimes()
		executor.EXPECT().SetResource(gomock.Any()).AnyTimes()
		executor.EXPECT().Cleanup(gomock.Any()).Return(nil).AnyTimes()
		executor.EXPECT().RealtimeSummary().Return(nil).AnyTimes()
		executor.EXPECT().TaskMetaModified(gomock.Any(), gomock.Any()).DoAndReturn(
			func(ctx context.Context, newMeta []byte) error {
				runtimeInfo.currTaskMeta.Store(&newMeta)
				return nil
			}).AnyTimes()
		executor.EXPECT().ResourceModified(gomock.Any(), gomock.Any()).DoAndReturn(
			func(ctx context.Context, newResource *proto.StepResource) error {
				runtimeInfo.currTaskConcurrency.Store(newResource.CPU.Capacity())
				return nil
			}).AnyTimes()
		return executor, nil
	})
	testutil.RegisterExampleTask(t, schedulerExt, executorExt, testutil.GetCommonCleanUpRoutine(c.MockCtrl))
	return c, runtimeInfo, subtaskCh, modifyWaitCh, &testModifyWhenSubtaskRun
}

func TestModifyTaskConcurrencyAndMeta(t *testing.T) {
	c, runtimeInfo, subtaskCh, modifyWaitCh, testModifyWhenSubtaskRun := prepareModifyTaskTest(t, 1)
	resetRuntimeInfoFn := func() {
		*runtimeInfo = collectedRuntimeInfo{}
	}
	scope := handle.GetTargetScope()
	t.Run("modify pending task concurrency", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		var theTask *proto.Task
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k1", proto.TaskTypeExample, "", 3, scope, 0, []byte("init"))
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStatePending,
					Modifications: []proto.Modification{
						{Type: proto.ModifyConcurrency, To: 7},
					},
				}))
				theTask = task
				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
				require.NoError(t, err)
				require.Equal(t, proto.TaskStateModifying, gotTask.State)
				require.Equal(t, 3, gotTask.Concurrency)
				<-modifySyncCh
			})
		})
		modifySyncCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})
	t.Run("modify running task concurrency at step two", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeRefreshTask", func(task *proto.Task) {
			if task.State != proto.TaskStateRunning && task.Step != proto.StepTwo {
				return
			}
			once.Do(func() {
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStateRunning,
					Modifications: []proto.Modification{
						{Type: proto.ModifyConcurrency, To: 7},
					},
				}))
				<-modifySyncCh
			})
		})
		task, err := handle.SubmitTask(c.Ctx, "k2", proto.TaskTypeExample, "", 3, scope, 0, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		// finish StepOne
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		// wait for the task to move to the 'modifying' state
		modifySyncCh <- struct{}{}
		// wait for the task to move back to the 'running' state
		require.Eventually(t, func() bool {
			gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, task.ID)
			require.NoError(t, err2)
			return gotTask.State == proto.TaskStateRunning
		}, 10*time.Second, 100*time.Millisecond)
		// finish StepTwo
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})
	t.Run("modify running task concurrency at second subtask of step two", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/taskexecutor/beforeGetTaskByIDInRun",
			func(taskID int64) {
				if len(runtimeInfo.subtaskInfos) == 3 {
					require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, taskID, &proto.ModifyParam{
						PrevState: proto.TaskStateRunning,
						Modifications: []proto.Modification{
							{Type: proto.ModifyConcurrency, To: 7},
						},
					}))
					// wait for the task to move back to the 'running' state
					require.Eventually(t, func() bool {
						gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, taskID)
						require.NoError(t, err2)
						return gotTask.State == proto.TaskStateRunning
					}, 10*time.Second, 100*time.Millisecond)
				}
			},
		)
		task, err := handle.SubmitTask(c.Ctx, "k2-2", proto.TaskTypeExample, "", 3, scope, 0, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepOne, Concurrency: 3},
			{Step: proto.StepTwo, Concurrency: 3},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})
	t.Run("modify paused task concurrency", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		syncCh := make(chan struct{})
		var theTask *proto.Task
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k3", proto.TaskTypeExample, "", 3, scope, 0, nil)
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				found, err := c.TaskMgr.PauseTask(c.Ctx, task.Key)
				require.NoError(t, err)
				require.True(t, found)
				theTask = task
				<-syncCh
			})
		})
		syncCh <- struct{}{}
		taskBase := testutil.WaitTaskDoneOrPaused(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStatePaused, taskBase.State)
		require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, theTask.ID, &proto.ModifyParam{
			PrevState: proto.TaskStatePaused,
			Modifications: []proto.Modification{
				{Type: proto.ModifyConcurrency, To: 7},
			},
		}))
		taskBase = testutil.WaitTaskDoneOrPaused(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStatePaused, taskBase.State)
		found, err := c.TaskMgr.ResumeTask(c.Ctx, theTask.Key)
		require.NoError(t, err)
		require.True(t, found)
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})
	t.Run("modify pending task concurrency, but other owner already done it", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		var theTask *proto.Task
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k4", proto.TaskTypeExample, "", 3, scope, 0, nil)
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStatePending,
					Modifications: []proto.Modification{
						{Type: proto.ModifyConcurrency, To: 7},
					},
				}))
				theTask = task
				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
				require.NoError(t, err)
				require.Equal(t, proto.TaskStateModifying, gotTask.State)
				require.Equal(t, 3, gotTask.Concurrency)
			})
		})
		var onceForRefresh sync.Once
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/afterRefreshTask",
			func(task *proto.Task) {
				onceForRefresh.Do(func() {
					require.Equal(t, proto.TaskStateModifying, task.State)
					taskClone := *task
					taskClone.Concurrency = 7
					require.NoError(t, c.TaskMgr.ModifiedTask(c.Ctx, &taskClone))
					gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, task.ID)
					require.NoError(t, err)
					require.Equal(t, proto.TaskStatePending, gotTask.State)
					<-modifySyncCh
				})
			},
		)
		modifySyncCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []subtaskRuntimeInfo{
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepOne, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
			{Step: proto.StepTwo, Concurrency: 7},
		}, runtimeInfo.subtaskInfos)
	})
	t.Run("modify pending task meta, only check the scheduler part", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		var theTask *proto.Task
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() {
			once.Do(func() {
				task, err := handle.SubmitTask(c.Ctx, "k5", proto.TaskTypeExample, "", 3, scope, 0, []byte("init"))
				require.NoError(t, err)
				require.Equal(t, 3, task.Concurrency)
				require.EqualValues(t, []byte("init"), task.Meta)
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStatePending,
					Modifications: []proto.Modification{
						{Type: proto.ModifyMaxWriteSpeed, To: 123},
					},
				}))
				theTask = task
				gotTask, err := c.TaskMgr.GetTaskBaseByID(c.Ctx, theTask.ID)
				require.NoError(t, err)
				require.Equal(t, proto.TaskStateModifying, gotTask.State)
				require.Equal(t, 3, gotTask.Concurrency)
				<-modifySyncCh
			})
		})
		modifySyncCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, theTask.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.EqualValues(t, []byte("modify_max_write_speed=123"), *runtimeInfo.currTaskMeta.Load())
	})
	t.Run("modify meta and increase concurrency when subtask is running, and apply success", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		defer testModifyWhenSubtaskRun.Store(false)
		testModifyWhenSubtaskRun.Store(true)
		task, err := handle.SubmitTask(c.Ctx, "k6", proto.TaskTypeExample, "", 3, scope, 0, []byte("init"))
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		require.EqualValues(t, []byte("init"), task.Meta)
		modifyWaitCh <- struct{}{}
		require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
			PrevState: proto.TaskStateRunning,
			Modifications: []proto.Modification{
				{Type: proto.ModifyConcurrency, To: 7},
				{Type: proto.ModifyMaxWriteSpeed, To: 123},
			},
		}))
		require.Eventually(t, func() bool {
			return "modify_max_write_speed=123" == string(*runtimeInfo.currTaskMeta.Load()) &&
				runtimeInfo.currTaskConcurrency.Load() == 7
		}, 10*time.Second, 100*time.Millisecond)
		testModifyWhenSubtaskRun.Store(false)
		modifyWaitCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.Equal(t, "modify_max_write_speed=123", string(*runtimeInfo.currTaskMeta.Load()))
		require.Equal(t, int64(7), runtimeInfo.currTaskConcurrency.Load())
	})
	t.Run("modify meta and decrease concurrency when subtask is running, and apply success", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		defer testModifyWhenSubtaskRun.Store(false)
		testModifyWhenSubtaskRun.Store(true)
		task, err := handle.SubmitTask(c.Ctx, "k7", proto.TaskTypeExample, "", 9, scope, 0, []byte("init"))
		require.NoError(t, err)
		require.Equal(t, 9, task.Concurrency)
		require.EqualValues(t, []byte("init"), task.Meta)
		modifyWaitCh <- struct{}{}
		require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
			PrevState: proto.TaskStateRunning,
			Modifications: []proto.Modification{
				{Type: proto.ModifyConcurrency, To: 5},
				{Type: proto.ModifyMaxWriteSpeed, To: 456},
			},
		}))
		require.Eventually(t, func() bool {
			return "modify_max_write_speed=456" == string(*runtimeInfo.currTaskMeta.Load()) &&
				runtimeInfo.currTaskConcurrency.Load() == 5
		}, 10*time.Second, 100*time.Millisecond)
		testModifyWhenSubtaskRun.Store(false)
		modifyWaitCh <- struct{}{}
		// finish subtasks
		for range 5 {
			subtaskCh <- struct{}{}
		}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
		require.Equal(t, "modify_max_write_speed=456", string(*runtimeInfo.currTaskMeta.Load()))
		require.Equal(t, int64(5), runtimeInfo.currTaskConcurrency.Load())
	})
	t.Run("modify running task max node count", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeRefreshTask", func(task *proto.Task) {
			if task.State != proto.TaskStateRunning && task.Step != proto.StepOne {
				return
			}
			once.Do(func() {
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStateRunning,
					Modifications: []proto.Modification{
						{Type: proto.ModifyMaxNodeCount, To: 200},
					},
				}))
				<-modifySyncCh
			})
		})
		task, err := handle.SubmitTask(c.Ctx, "k8", proto.TaskTypeExample, "", 3, scope, 1, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		require.EqualValues(t, 1, task.MaxNodeCount)
		// finish StepOne
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		// wait for the task to move to the 'modifying' state
		modifySyncCh <- struct{}{}
		// wait for the task to move back to the 'running' state with the modification applied
		require.Eventually(t, func() bool {
			gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, task.ID)
			require.NoError(t, err2)
			return gotTask.State == proto.TaskStateRunning && gotTask.MaxNodeCount == 200
		}, 10*time.Second, 100*time.Millisecond)
		// finish StepTwo
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
	})
}

func TestModifyTaskMaxNodeCountForSubtaskBalance(t *testing.T) {
	c, runtimeInfo, subtaskCh, _, _ := prepareModifyTaskTest(t, 3)
	resetRuntimeInfoFn := func() {
		*runtimeInfo = collectedRuntimeInfo{}
	}
	scope := handle.GetTargetScope()
	t.Run("modify running task max node count, task can use more node after balance", func(t *testing.T) {
		defer resetRuntimeInfoFn()
		var once sync.Once
		modifySyncCh := make(chan struct{})
		testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeRefreshTask", func(task *proto.Task) {
			if task.State != proto.TaskStateRunning && task.Step != proto.StepOne {
				return
			}
			once.Do(func() {
				<-modifySyncCh
				require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{
					PrevState: proto.TaskStateRunning,
					Modifications: []proto.Modification{
						{Type: proto.ModifyMaxNodeCount, To: 2},
					},
				}))
				<-modifySyncCh
			})
		})
		task, err := handle.SubmitTask(c.Ctx, "k8", proto.TaskTypeExample, "", 3, scope, 1, nil)
		require.NoError(t, err)
		require.Equal(t, 3, task.Concurrency)
		require.EqualValues(t, 1, task.MaxNodeCount)
		// only 1 subtask can be running at the same time
		require.Eventually(t, func() bool {
			return runtimeInfo.activeSubtaskCount.Load() == 1
		}, 10*time.Second, 100*time.Millisecond)
		require.NoError(t, c.TaskMgr.WithNewSession(func(se sessionctx.Context) error {
			rows, err2 := sqlexec.ExecSQL(c.Ctx, se.GetSQLExecutor(),
				"select count(distinct exec_id) from mysql.tidb_background_subtask where task_key = %?", task.ID)
			require.NoError(t, err2)
			require.Equal(t, 1, len(rows))
			require.EqualValues(t, 1, rows[0].GetInt64(0))
			return nil
		}))
		// let the task move to the 'modifying' state
		modifySyncCh <- struct{}{}
		modifySyncCh <- struct{}{}
		// wait for the task to move back to the 'running' state with the modification applied
		require.Eventually(t, func() bool {
			gotTask, err2 := c.TaskMgr.GetTaskByID(c.Ctx, task.ID)
			require.NoError(t, err2)
			return gotTask.State == proto.TaskStateRunning && gotTask.MaxNodeCount == 2
		}, 10*time.Second, 100*time.Millisecond)
		// now 2 subtasks can be running at the same time
		require.Eventually(t, func() bool {
			return runtimeInfo.activeSubtaskCount.Load() == 2
		}, 10*time.Second, 100*time.Millisecond)
		require.NoError(t, c.TaskMgr.WithNewSession(func(se sessionctx.Context) error {
			rows, err2 := sqlexec.ExecSQL(c.Ctx, se.GetSQLExecutor(),
				"select count(distinct exec_id) from mysql.tidb_background_subtask where task_key = %?", task.ID)
			require.NoError(t, err2)
			require.Equal(t, 1, len(rows))
			require.EqualValues(t, 2, rows[0].GetInt64(0))
			return nil
		}))
		// finish StepOne
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		// finish StepTwo
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		subtaskCh <- struct{}{}
		task2Base := testutil.WaitTaskDone(c.Ctx, t, task.Key)
		require.Equal(t, proto.TaskStateSucceed, task2Base.State)
	})
}