disttask: dynamic dispatch subtasks (#46593)

ref pingcap/tidb#46258
This commit is contained in:
EasonBall
2023-09-12 09:00:09 +08:00
committed by GitHub
parent 4716f1b070
commit dbb493ff22
21 changed files with 525 additions and 195 deletions

View File

@ -50,11 +50,12 @@ func NewBackfillingDispatcherExt(d DDL) (dispatcher.Extension, error) {
}, nil
}
// OnTick implements dispatcher.Extension interface.
func (*backfillingDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
// OnNextStage generate next stage's plan.
func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) {
// OnNextSubtasksBatch generates a batch of the next stage's plan.
func (h *backfillingDispatcherExt) OnNextSubtasksBatch(ctx context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) ([][]byte, error) {
var globalTaskMeta BackfillGlobalMeta
if err := json.Unmarshal(gTask.Meta, &globalTaskMeta); err != nil {
return nil, err
@ -76,7 +77,6 @@ func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher
if err != nil {
return nil, err
}
gTask.Step = proto.StepOne
return subTaskMetas, nil
}
@ -87,7 +87,6 @@ func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher
if err != nil {
return nil, err
}
gTask.Step = proto.StepOne
return subtaskMeta, nil
case proto.StepOne:
serverNodes, err := dispatcher.GenerateSchedulerNodes(ctx)
@ -103,15 +102,22 @@ func (h *backfillingDispatcherExt) OnNextStage(ctx context.Context, _ dispatcher
for range serverNodes {
subTaskMetas = append(subTaskMetas, metaBytes)
}
gTask.Step = proto.StepTwo
return subTaskMetas, nil
case proto.StepTwo:
return nil, nil
default:
return nil, nil
}
}
// StageFinished checks whether the current stage is finished.
func (*backfillingDispatcherExt) StageFinished(_ *proto.Task) bool {
return true
}
// Finished checks whether the current task is finished.
func (*backfillingDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == proto.StepOne
}
// OnErrStage generates the error-handling stage's plan.
func (*backfillingDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task, receiveErr []error) (meta []byte, err error) {
// We do not need extra meta info when rolling back
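The key change in this file: the backfilling extension no longer advances gTask.Step itself (the gTask.Step = proto.StepOne/StepTwo assignments are removed); it only returns subtask metas for whatever step the framework says the task is on. A minimal, self-contained sketch of that contract, using simplified local types rather than the real TiDB packages (the stage names are illustrative):

```go
package main

import "fmt"

// Step constants mirroring the new numbering in this PR: StepInit is 0.
const (
	StepInit int64 = 0
	StepOne  int64 = 1
	StepTwo  int64 = 2
)

// nextSubtasksBatch returns the subtask metas for the task's current step
// and leaves step advancement to the framework. An empty batch signals
// that this stage has nothing (more) to dispatch.
func nextSubtasksBatch(step int64, partitions []string) [][]byte {
	switch step {
	case StepInit:
		// One read-and-ingest subtask per partition (single-table case elided).
		metas := make([][]byte, 0, len(partitions))
		for _, p := range partitions {
			metas = append(metas, []byte("ingest:"+p))
		}
		return metas
	case StepOne:
		// Second stage, e.g. one merge subtask per scheduler node.
		return [][]byte{[]byte("merge")}
	default:
		// StepTwo and beyond: the task is done.
		return nil
	}
}

func main() {
	for step := StepInit; step <= StepTwo; step++ {
		metas := nextSubtasksBatch(step, []string{"p0", "p1", "p2"})
		fmt.Printf("step %d -> %d subtasks\n", step, len(metas))
	}
}
```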

View File

@ -38,7 +38,7 @@ func TestBackfillingDispatcher(t *testing.T) {
tk := testkit.NewTestKit(t, store)
tk.MustExec("use test")
// test partition table OnNextStage.
/// 1. test partition table.
tk.MustExec("create table tp1(id int primary key, v int) PARTITION BY RANGE (id) (\n " +
"PARTITION p0 VALUES LESS THAN (10),\n" +
"PARTITION p1 VALUES LESS THAN (100),\n" +
@ -48,9 +48,10 @@ func TestBackfillingDispatcher(t *testing.T) {
tbl, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("tp1"))
require.NoError(t, err)
tblInfo := tbl.Meta()
metas, err := dsp.OnNextStage(context.Background(), nil, gTask)
// 1.1 OnNextSubtasksBatch
metas, err := dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
require.NoError(t, err)
require.Equal(t, proto.StepOne, gTask.Step)
require.Equal(t, len(tblInfo.Partition.Definitions), len(metas))
for i, par := range tblInfo.Partition.Definitions {
var subTask ddl.BackfillSubTaskMeta
@ -58,13 +59,14 @@ func TestBackfillingDispatcher(t *testing.T) {
require.Equal(t, par.ID, subTask.PhysicalTableID)
}
// test partition table OnNextStage after step1 finished.
// 1.2 test partition table OnNextSubtasksBatch after StepInit finished.
gTask.State = proto.TaskStateRunning
metas, err = dsp.OnNextStage(context.Background(), nil, gTask)
gTask.Step++
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
require.NoError(t, err)
require.Equal(t, 0, len(metas))
// test partition table OnErrStage.
// 1.3 test partition table OnErrStage.
errMeta, err := dsp.OnErrStage(context.Background(), nil, gTask, []error{errors.New("mockErr")})
require.NoError(t, err)
require.Nil(t, errMeta)
@ -73,10 +75,30 @@ func TestBackfillingDispatcher(t *testing.T) {
require.NoError(t, err)
require.Nil(t, errMeta)
/// 2. test non-partition table.
// 2.1 empty table
tk.MustExec("create table t1(id int primary key, v int)")
gTask = createAddIndexGlobalTask(t, dom, "test", "t1", ddl.BackfillTaskType)
_, err = dsp.OnNextStage(context.Background(), nil, gTask)
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
require.NoError(t, err)
require.Equal(t, 0, len(metas))
// 2.2 non empty table.
tk.MustExec("create table t2(id bigint auto_random primary key)")
tk.MustExec("insert into t2 values (), (), (), (), (), ()")
tk.MustExec("insert into t2 values (), (), (), (), (), ()")
tk.MustExec("insert into t2 values (), (), (), (), (), ()")
tk.MustExec("insert into t2 values (), (), (), (), (), ()")
gTask = createAddIndexGlobalTask(t, dom, "test", "t2", ddl.BackfillTaskType)
// 2.2.1 stepInit
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
require.NoError(t, err)
require.Equal(t, 1, len(metas))
// 2.2.2 stepOne
gTask.Step++
gTask.State = proto.TaskStateRunning
metas, err = dsp.OnNextSubtasksBatch(context.Background(), nil, gTask)
require.NoError(t, err)
require.Equal(t, 1, len(metas))
}
func createAddIndexGlobalTask(t *testing.T, dom *domain.Domain, dbName, tblName string, taskType string) *proto.Task {

View File

@ -1942,7 +1942,7 @@ func (w *worker) updateJobRowCount(taskKey string, jobID int64) {
logutil.BgLogger().Warn("cannot get global task", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return
}
rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepOne)
rowCount, err := taskMgr.GetSubtaskRowCount(gTask.ID, proto.StepInit)
if err != nil {
logutil.BgLogger().Warn("cannot get subtask row count", zap.String("category", "ddl"), zap.String("task_key", taskKey), zap.Error(err))
return

View File

@ -98,9 +98,9 @@ func newBackfillDistScheduler(ctx context.Context, id string, taskID int64, task
func (s *backfillDistScheduler) GetSubtaskExecutor(ctx context.Context, task *proto.Task, summary *execute.Summary) (execute.SubtaskExecutor, error) {
switch task.Step {
case proto.StepOne:
case proto.StepInit:
return NewBackfillSchedulerHandle(ctx, task.Meta, s.d, false, summary)
case proto.StepTwo:
case proto.StepOne:
return NewBackfillSchedulerHandle(ctx, task.Meta, s.d, true, nil)
default:
return nil, errors.Errorf("unknown backfill step %d for task %d", task.Step, task.ID)

View File

@ -4,6 +4,7 @@ go_test(
name = "framework_test",
timeout = "short",
srcs = [
"framework_dynamic_dispatch_test.go",
"framework_err_handling_test.go",
"framework_ha_test.go",
"framework_rollback_test.go",
@ -11,7 +12,7 @@ go_test(
],
flaky = True,
race = "off",
shard_count = 24,
shard_count = 26,
deps = [
"//disttask/framework/dispatcher",
"//disttask/framework/mock",

View File

@ -57,6 +57,8 @@ var (
type TaskHandle interface {
// GetPreviousSubtaskMetas gets previous subtask metas.
GetPreviousSubtaskMetas(taskID int64, step int64) ([][]byte, error)
// UpdateTask updates the task in the tidb_global_task table.
UpdateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) error
storage.SessionExecutor
}
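UpdateTask is promoted from an unexported BaseDispatcher helper onto the TaskHandle interface, so extensions can persist task changes (for example, task meta rewritten while finishing or failing a job) through the handle. A hypothetical, simplified sketch of the pattern; Task, memHandle, and finishJob here are illustrative stand-ins, not TiDB code:

```go
package main

import "fmt"

// Simplified stand-ins for the framework types; not the real TiDB API.
type Task struct {
	State string
	Meta  []byte
}

// TaskHandle mirrors the method this diff adds to dispatcher.TaskHandle:
// extensions can now persist task changes through the handle themselves.
type TaskHandle interface {
	UpdateTask(taskState string, retryTimes int) error
}

// memHandle is a toy in-memory implementation for the sketch.
type memHandle struct{ task *Task }

func (h *memHandle) UpdateTask(state string, _ int) error {
	h.task.State = state // the real handle writes tidb_global_task with retries
	return nil
}

// finishJob sketches how an extension (like importinto's finishJob) can flush
// meta changes through the handle before its own bookkeeping.
func finishJob(h TaskHandle, task *Task) error {
	task.Meta = []byte(`{"result":"redacted"}`) // e.g. after redactSensitiveInfo
	return h.UpdateTask(task.State, 3)
}

func main() {
	t := &Task{State: "running"}
	err := finishJob(&memHandle{task: t}, t)
	fmt.Println(err, t.State, string(t.Meta))
}
```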
@ -185,7 +187,7 @@ func (d *BaseDispatcher) scheduleTask() {
// handle task in cancelling state, dispatch revert subtasks.
func (d *BaseDispatcher) onCancelling() error {
logutil.Logger(d.logCtx).Debug("on cancelling state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
logutil.Logger(d.logCtx).Info("on cancelling state", zap.String("state", d.Task.State), zap.Int64("stage", d.Task.Step))
errs := []error{errors.New("cancel")}
return d.onErrHandlingStage(errs)
}
@ -198,11 +200,10 @@ func (d *BaseDispatcher) onReverting() error {
logutil.Logger(d.logCtx).Warn("check task failed", zap.Error(err))
return err
}
prevStageFinished := cnt == 0
if prevStageFinished {
if cnt == 0 {
// Finish the rollback step.
logutil.Logger(d.logCtx).Info("update the task to reverted state")
return d.updateTask(proto.TaskStateReverted, nil, RetrySQLTimes)
logutil.Logger(d.logCtx).Info("all reverting tasks finished, update the task to reverted state")
return d.UpdateTask(proto.TaskStateReverted, nil, RetrySQLTimes)
}
// Wait until all subtasks in this stage are finished.
d.OnTick(d.ctx, d.Task)
@ -236,9 +237,19 @@ func (d *BaseDispatcher) onRunning() error {
return err
}
prevStageFinished := cnt == 0
if prevStageFinished {
logutil.Logger(d.logCtx).Info("previous stage finished, generate dist plan", zap.Int64("stage", d.Task.Step))
if cnt == 0 {
logutil.Logger(d.logCtx).Info("previous subtasks finished, generate dist plan", zap.Int64("stage", d.Task.Step))
// When all subtasks are dispatched and processed, mark the task as succeeded.
if d.Finished(d.Task) {
d.Task.StateUpdateTime = time.Now().UTC()
logutil.Logger(d.logCtx).Info("all subtasks dispatched and processed, finish the task")
err := d.UpdateTask(proto.TaskStateSucceed, nil, RetrySQLTimes)
if err != nil {
logutil.Logger(d.logCtx).Warn("update task failed", zap.Error(err))
return err
}
return nil
}
return d.onNextStage()
}
// Check if any nodes are down.
@ -309,7 +320,29 @@ func (d *BaseDispatcher) replaceDeadNodesIfAny() error {
return nil
}
func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
func (d *BaseDispatcher) addSubtasks(subtasks []*proto.Subtask) (err error) {
for i := 0; i < RetrySQLTimes; i++ {
err = d.taskMgr.AddSubTasks(d.Task, subtasks)
if err == nil {
break
}
if i%10 == 0 {
logutil.Logger(d.logCtx).Warn("addSubtasks failed", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step),
zap.Int("subtask cnt", len(subtasks)),
zap.Int("retry times", i), zap.Error(err))
}
time.Sleep(RetrySQLInterval)
}
if err != nil {
logutil.Logger(d.logCtx).Warn("addSubtasks failed", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step),
zap.Int("subtask cnt", len(subtasks)),
zap.Int("retry times", RetrySQLTimes), zap.Error(err))
}
return err
}
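addSubtasks wraps the storage write in the framework's usual bounded-retry loop, logging only every tenth failure to keep the log quiet. A generic, self-contained sketch of the same pattern; the retry count and interval below are placeholders, not the real RetrySQLTimes/RetrySQLInterval values:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// Placeholder values; the real RetrySQLTimes/RetrySQLInterval live in the
// dispatcher package and may differ.
const (
	retrySQLTimes    = 30
	retrySQLInterval = 10 * time.Millisecond
)

// withSQLRetry sketches the loop addSubtasks uses: retry a storage write a
// bounded number of times, logging only every 10th failure to limit noise.
func withSQLRetry(op func() error) error {
	var err error
	for i := 0; i < retrySQLTimes; i++ {
		if err = op(); err == nil {
			return nil
		}
		if i%10 == 0 {
			fmt.Printf("write failed, retry %d: %v\n", i, err)
		}
		time.Sleep(retrySQLInterval)
	}
	return err
}

func main() {
	calls := 0
	err := withSQLRetry(func() error {
		calls++
		if calls < 3 {
			return errors.New("transient")
		}
		return nil
	})
	fmt.Println("calls:", calls, "err:", err)
}
```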
// UpdateTask update the task in tidb_global_task table.
func (d *BaseDispatcher) UpdateTask(taskState string, newSubTasks []*proto.Subtask, retryTimes int) (err error) {
prevState := d.Task.State
d.Task.State = taskState
if !VerifyTaskStateTransform(prevState, taskState) {
@ -330,7 +363,7 @@ func (d *BaseDispatcher) updateTask(taskState string, newSubTasks []*proto.Subta
}
if i%10 == 0 {
logutil.Logger(d.logCtx).Warn("updateTask first failed", zap.String("from", prevState), zap.String("to", d.Task.State),
zap.Int("retry times", retryTimes), zap.Error(err))
zap.Int("retry times", i), zap.Error(err))
}
time.Sleep(RetrySQLInterval)
}
@ -351,11 +384,11 @@ func (d *BaseDispatcher) onErrHandlingStage(receiveErr []error) error {
}
// 2. dispatch revert dist-plan to EligibleInstances.
return d.dispatchSubTask4Revert(d.Task, meta)
return d.dispatchSubTask4Revert(meta)
}
func (d *BaseDispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) error {
instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, task)
func (d *BaseDispatcher) dispatchSubTask4Revert(meta []byte) error {
instanceIDs, err := d.GetAllSchedulerIDs(d.ctx, d.Task)
if err != nil {
logutil.Logger(d.logCtx).Warn("get task's all instances failed", zap.Error(err))
return err
@ -363,54 +396,74 @@ func (d *BaseDispatcher) dispatchSubTask4Revert(task *proto.Task, meta []byte) e
subTasks := make([]*proto.Subtask, 0, len(instanceIDs))
for _, id := range instanceIDs {
subTasks = append(subTasks, proto.NewSubtask(task.ID, task.Type, id, meta))
subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, id, meta))
}
return d.updateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes)
return d.UpdateTask(proto.TaskStateReverting, subTasks, RetrySQLTimes)
}
func (d *BaseDispatcher) onNextStage() error {
// 1. generate the needed global task meta and subTask meta (dist-plan).
metas, err := d.OnNextStage(d.ctx, d, d.Task)
if err != nil {
return d.handlePlanErr(err)
}
// 2. dispatch dist-plan to EligibleInstances.
return d.dispatchSubTask(d.Task, metas)
}
/// Dynamically dispatch subtasks.
failpoint.Inject("mockDynamicDispatchErr", func() {
failpoint.Return(errors.New("mockDynamicDispatchErr"))
})
func (d *BaseDispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.Task.State), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))
// 1. Adjust the global task's concurrency.
if task.Concurrency == 0 {
task.Concurrency = DefaultSubtaskConcurrency
}
if task.Concurrency > MaxSubtaskConcurrency {
task.Concurrency = MaxSubtaskConcurrency
}
retryTimes := RetrySQLTimes
// 2. Special handling for the new tasks.
if task.State == proto.TaskStatePending {
// TODO: Consider using TS.
nowTime := time.Now().UTC()
task.StartTime = nowTime
task.StateUpdateTime = nowTime
retryTimes = nonRetrySQLTime
}
if len(metas) == 0 {
task.StateUpdateTime = time.Now().UTC()
// Write the global task meta into the storage.
err := d.updateTask(proto.TaskStateSucceed, nil, retryTimes)
if err != nil {
logutil.Logger(d.logCtx).Warn("update task failed", zap.Error(err))
if d.Task.State == proto.TaskStatePending {
if d.Task.Concurrency == 0 {
d.Task.Concurrency = DefaultSubtaskConcurrency
}
if d.Task.Concurrency > MaxSubtaskConcurrency {
d.Task.Concurrency = MaxSubtaskConcurrency
}
d.Task.StateUpdateTime = time.Now().UTC()
if err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes); err != nil {
return err
}
} else if d.StageFinished(d.Task) {
// 2. When the previous stage has finished, move to the next stage.
d.Task.Step++
logutil.Logger(d.logCtx).Info("previous stage finished, run into next stage", zap.Int64("from", d.Task.Step-1), zap.Int64("to", d.Task.Step))
d.Task.StateUpdateTime = time.Now().UTC()
err := d.UpdateTask(proto.TaskStateRunning, nil, RetrySQLTimes)
if err != nil {
return err
}
return nil
}
// 3. select all available TiDB nodes for task.
serverNodes, err := d.GetEligibleInstances(d.ctx, task)
for {
// 3. generate a batch of subtasks.
metas, err := d.OnNextSubtasksBatch(d.ctx, d, d.Task)
if err != nil {
logutil.Logger(d.logCtx).Warn("generate part of subtasks failed", zap.Error(err))
return d.handlePlanErr(err)
}
failpoint.Inject("mockDynamicDispatchErr1", func() {
failpoint.Return(errors.New("mockDynamicDispatchErr1"))
})
// 4. dispatch the batch of subtasks to EligibleInstances.
err = d.dispatchSubTask(metas)
if err != nil {
return err
}
if d.StageFinished(d.Task) {
break
}
failpoint.Inject("mockDynamicDispatchErr2", func() {
failpoint.Return(errors.New("mockDynamicDispatchErr2"))
})
}
return nil
}
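This loop is the heart of dynamic dispatch: rather than producing one monolithic plan per stage, the dispatcher repeatedly asks the extension for a batch, persists it, and stops once StageFinished reports the stage complete. A self-contained model of the control flow, with simplified signatures (the real loop also handles failpoints and the pending/step transitions above):

```go
package main

import "fmt"

// ext captures the two extension hooks the loop relies on.
type ext interface {
	OnNextSubtasksBatch(step int64) ([][]byte, error)
	StageFinished(step int64) bool
}

// dispatchStage keeps pulling batches for the current step until the
// extension says the stage is finished; each batch is persisted via dispatch.
func dispatchStage(e ext, step int64, dispatch func([][]byte) error) error {
	for {
		metas, err := e.OnNextSubtasksBatch(step)
		if err != nil {
			return err // handlePlanErr territory in the real code
		}
		if err := dispatch(metas); err != nil {
			return err
		}
		if e.StageFinished(step) {
			return nil
		}
	}
}

// threeBatches finishes its stage after emitting three one-subtask batches.
type threeBatches struct{ sent int }

func (t *threeBatches) OnNextSubtasksBatch(int64) ([][]byte, error) {
	t.sent++
	return [][]byte{[]byte(fmt.Sprintf("batch%d", t.sent))}, nil
}

func (t *threeBatches) StageFinished(int64) bool { return t.sent >= 3 }

func main() {
	err := dispatchStage(&threeBatches{}, 0, func(m [][]byte) error {
		fmt.Println("dispatched", len(m), "subtasks")
		return nil
	})
	fmt.Println("stage done, err:", err)
}
```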
func (d *BaseDispatcher) dispatchSubTask(metas [][]byte) error {
logutil.Logger(d.logCtx).Info("dispatch subtasks", zap.String("state", d.Task.State), zap.Int64("step", d.Task.Step), zap.Uint64("concurrency", d.Task.Concurrency), zap.Int("subtasks", len(metas)))
// select all available TiDB nodes for the task.
serverNodes, err := d.GetEligibleInstances(d.ctx, d.Task)
logutil.Logger(d.logCtx).Debug("eligible instances", zap.Int("num", len(serverNodes)))
if err != nil {
@ -434,12 +487,13 @@ func (d *BaseDispatcher) dispatchSubTask(task *proto.Task, metas [][]byte) error
subTasks := make([]*proto.Subtask, 0, len(metas))
for i, meta := range metas {
// we assign the subtask to the instance in a round-robin way.
// TODO: assign the subtask to the instance according to the system load of each node
pos := i % len(serverNodes)
instanceID := disttaskutil.GenerateExecID(serverNodes[pos].IP, serverNodes[pos].Port)
logutil.Logger(d.logCtx).Debug("create subtasks", zap.String("instanceID", instanceID))
subTasks = append(subTasks, proto.NewSubtask(task.ID, task.Type, instanceID, meta))
subTasks = append(subTasks, proto.NewSubtask(d.Task.ID, d.Task.Type, instanceID, meta))
}
return d.updateTask(proto.TaskStateRunning, subTasks, RetrySQLTimes)
return d.addSubtasks(subTasks)
}
func (d *BaseDispatcher) handlePlanErr(err error) error {
@ -449,7 +503,7 @@ func (d *BaseDispatcher) handlePlanErr(err error) error {
}
d.Task.Error = err
// state transform: pending -> failed.
return d.updateTask(proto.TaskStateFailed, nil, RetrySQLTimes)
return d.UpdateTask(proto.TaskStateFailed, nil, RetrySQLTimes)
}
// GenerateSchedulerNodes generates eligible TiDB nodes.

View File

@ -48,7 +48,7 @@ type testDispatcherExt struct{}
func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) {
func (*testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) {
return nil, nil
}
@ -66,12 +66,20 @@ func (*testDispatcherExt) IsRetryableErr(error) bool {
return true
}
func (dsp *testDispatcherExt) StageFinished(task *proto.Task) bool {
return true
}
func (dsp *testDispatcherExt) Finished(task *proto.Task) bool {
return false
}
type numberExampleDispatcherExt struct{}
func (*numberExampleDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
func (n *numberExampleDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) {
func (n *numberExampleDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, task *proto.Task) (metas [][]byte, err error) {
if task.State == proto.TaskStatePending {
task.Step = proto.StepInit
}
@ -104,6 +112,14 @@ func (*numberExampleDispatcherExt) IsRetryableErr(error) bool {
return true
}
func (*numberExampleDispatcherExt) StageFinished(task *proto.Task) bool {
return true
}
func (*numberExampleDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == proto.StepTwo
}
func MockDispatcherManager(t *testing.T, pool *pools.ResourcePool) (*dispatcher.Manager, *storage.TaskManager) {
ctx := context.WithValue(context.Background(), "etcd", true)
mgr := storage.NewTaskManager(util.WithInternalSourceType(ctx, "taskManager"), pool)
@ -255,7 +271,7 @@ func checkDispatch(t *testing.T, taskCnt int, isSucc bool, isCancel bool) {
require.NoError(t, err)
taskIDs = append(taskIDs, taskID)
}
// test OnNextStage.
// test OnNextSubtasksBatch.
checkGetRunningTaskCnt(taskCnt)
tasks := checkTaskRunningCnt()
for i, taskID := range taskIDs {

View File

@ -31,23 +31,35 @@ type Extension interface {
// OnTick is used to handle the ticker event. If the business implementation needs to do some periodic work, it can
// be done here, but don't do too much work, because the ticker interval is small and heavy work will block it.
// The ticker event is generated every checkTaskRunningInterval, and only when the task is NOT FINISHED and has NO ERROR.
OnTick(ctx context.Context, gTask *proto.Task)
// OnNextStage is used to move the task to next stage, if returns no error and there's no new subtasks
// the task is finished.
OnTick(ctx context.Context, task *proto.Task)
// OnNextSubtasksBatch is used to generate a batch of subtasks for the current stage.
// NOTE: don't change gTask.State inside; the framework manages it.
// it's called when:
// 1. the task is pending and entering its first step.
// 2. subtasks of previous step has all finished with no error.
OnNextStage(ctx context.Context, h TaskHandle, gTask *proto.Task) (subtaskMetas [][]byte, err error)
// 2. all dispatched subtasks have finished with no error.
OnNextSubtasksBatch(ctx context.Context, h TaskHandle, task *proto.Task) (subtaskMetas [][]byte, err error)
// OnErrStage is called when:
// 1. subtask is finished with error.
// 2. task is cancelled after we have dispatched some subtasks.
OnErrStage(ctx context.Context, h TaskHandle, gTask *proto.Task, receiveErr []error) (subtaskMeta []byte, err error)
OnErrStage(ctx context.Context, h TaskHandle, task *proto.Task, receiveErr []error) (subtaskMeta []byte, err error)
// GetEligibleInstances is used to get the eligible instances for the task.
// Under certain conditions we may want specific instances to run the task, such as instances with more disk.
GetEligibleInstances(ctx context.Context, gTask *proto.Task) ([]*infosync.ServerInfo, error)
GetEligibleInstances(ctx context.Context, task *proto.Task) ([]*infosync.ServerInfo, error)
// IsRetryableErr is used to check whether the error that occurred in the dispatcher is retryable.
IsRetryableErr(err error) bool
// StageFinished is used to check whether all subtasks in the current stage are dispatched and processed.
// StageFinished is called before generating a batch of subtasks.
StageFinished(task *proto.Task) bool
// Finished is used to check whether all subtasks for the task are dispatched and processed.
// Finished is called before generating a batch of subtasks.
// Once Finished returns true, the framework marks the task as succeeded.
Finished(task *proto.Task) bool
}
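Putting the new hooks together, the framework's decision on a scheduling round where all current subtasks are done looks roughly like this. This is a condensed, hypothetical model of onRunning/onNextStage with toy types, not the actual implementation:

```go
package main

import "fmt"

type task struct {
	state string
	step  int64
}

type extension interface {
	StageFinished(t *task) bool
	Finished(t *task) bool
}

// onRound models one scheduling round with no unfinished subtasks left:
// succeed the task, start it, advance the step, or keep batching.
func onRound(t *task, e extension) string {
	if e.Finished(t) {
		t.state = "succeed" // all stages dispatched and processed
		return "finish task"
	}
	if t.state == "pending" {
		t.state = "running" // first round: pending -> running, same step
	} else if e.StageFinished(t) {
		t.step++ // previous stage fully done: move to the next one
	}
	return "dispatch next batch(es) at current step"
}

// twoStep finishes each stage immediately and the whole task after step 1.
type twoStep struct{}

func (twoStep) StageFinished(t *task) bool { return true }
func (twoStep) Finished(t *task) bool      { return t.step >= 1 }

func main() {
	t := &task{state: "pending", step: 0}
	for i := 0; i < 4 && t.state != "succeed"; i++ {
		fmt.Printf("round %d: step=%d state=%s -> %s\n", i, t.step, t.state, onRound(t, twoStep{}))
	}
}
```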
// FactoryFn is used to create a dispatcher.

View File

@ -0,0 +1,114 @@
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package framework_test
import (
"context"
"fmt"
"sync"
"testing"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/disttask/framework/dispatcher"
"github.com/pingcap/tidb/disttask/framework/proto"
"github.com/pingcap/tidb/domain/infosync"
"github.com/pingcap/tidb/testkit"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
type testDynamicDispatcherExt struct {
cnt int
}
var _ dispatcher.Extension = (*testDynamicDispatcherExt)(nil)
func (*testDynamicDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {}
func (dsp *testDynamicDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
// step1
if gTask.Step == proto.StepInit && dsp.cnt < 3 {
dsp.cnt++
return [][]byte{
[]byte(fmt.Sprintf("task%d", dsp.cnt)),
[]byte(fmt.Sprintf("task%d", dsp.cnt)),
}, nil
}
// step2
if gTask.Step == proto.StepOne && dsp.cnt < 4 {
dsp.cnt++
return [][]byte{
[]byte(fmt.Sprintf("task%d", dsp.cnt)),
}, nil
}
return nil, nil
}
func (*testDynamicDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle, _ *proto.Task, _ []error) (meta []byte, err error) {
return nil, nil
}
func (dsp *testDynamicDispatcherExt) StageFinished(task *proto.Task) bool {
if task.Step == proto.StepInit && dsp.cnt >= 3 {
return true
}
if task.Step == proto.StepOne && dsp.cnt >= 4 {
return true
}
return false
}
func (dsp *testDynamicDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == proto.StepOne && dsp.cnt >= 4
}
func (*testDynamicDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return generateSchedulerNodes4Test()
}
func (*testDynamicDispatcherExt) IsRetryableErr(error) bool {
return true
}
func TestFrameworkDynamicBasic(t *testing.T) {
var m sync.Map
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &testDynamicDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 3)
DispatchTaskAndCheckSuccess("key1", t, &m)
distContext.Close()
}
func TestFrameworkDynamicHA(t *testing.T) {
var m sync.Map
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &testDynamicDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 3)
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr", "5*return()"))
DispatchTaskAndCheckSuccess("key1", t, &m)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr"))
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr1", "5*return()"))
DispatchTaskAndCheckSuccess("key2", t, &m)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr1"))
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr2", "5*return()"))
DispatchTaskAndCheckSuccess("key3", t, &m)
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/mockDynamicDispatchErr2"))
distContext.Close()
}

View File

@ -29,6 +29,7 @@ import (
type planErrDispatcherExt struct {
callTime int
cnt int
}
var (
@ -39,13 +40,13 @@ var (
func (*planErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
func (p *planErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.State == proto.TaskStatePending {
func (p *planErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
if p.callTime == 0 {
p.callTime++
return nil, errors.New("retryable err")
}
gTask.Step = proto.StepOne
p.cnt = 3
return [][]byte{
[]byte("task1"),
[]byte("task2"),
@ -53,7 +54,7 @@ func (p *planErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskH
}, nil
}
if gTask.Step == proto.StepOne {
gTask.Step = proto.StepTwo
p.cnt = 4
return [][]byte{
[]byte("task4"),
}, nil
@ -77,13 +78,31 @@ func (*planErrDispatcherExt) IsRetryableErr(error) bool {
return true
}
func (p *planErrDispatcherExt) StageFinished(task *proto.Task) bool {
if task.Step == proto.StepInit && p.cnt == 3 {
return true
}
if task.Step == proto.StepOne && p.cnt == 4 {
return true
}
return false
}
func (p *planErrDispatcherExt) Finished(task *proto.Task) bool {
if task.Step == proto.StepOne && p.cnt == 4 {
return true
}
return false
}
type planNotRetryableErrDispatcherExt struct {
cnt int
}
func (*planNotRetryableErrDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
func (p *planNotRetryableErrDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
func (p *planNotRetryableErrDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
return nil, errors.New("not retryable err")
}
@ -99,12 +118,25 @@ func (*planNotRetryableErrDispatcherExt) IsRetryableErr(error) bool {
return false
}
func (p *planNotRetryableErrDispatcherExt) StageFinished(task *proto.Task) bool {
if task.Step == proto.StepInit && p.cnt >= 3 {
return true
}
if task.Step == proto.StepOne && p.cnt >= 4 {
return true
}
return false
}
func (p *planNotRetryableErrDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == proto.StepOne && p.cnt >= 4
}
func TestPlanErr(t *testing.T) {
m := sync.Map{}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0})
RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0, 0})
distContext := testkit.NewDistExecutionContext(t, 2)
DispatchTaskAndCheckSuccess("key1", t, &m)
distContext.Close()
@ -115,7 +147,7 @@ func TestRevertPlanErr(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0})
RegisterTaskMeta(t, ctrl, &m, &planErrDispatcherExt{0, 0})
distContext := testkit.NewDistExecutionContext(t, 2)
DispatchTaskAndCheckSuccess("key1", t, &m)
distContext.Close()
@ -123,7 +155,6 @@ func TestRevertPlanErr(t *testing.T) {
func TestPlanNotRetryableErr(t *testing.T) {
m := sync.Map{}
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &planNotRetryableErrDispatcherExt{})

View File

@ -28,16 +28,18 @@ import (
"go.uber.org/mock/gomock"
)
type haTestFlowHandle struct{}
var _ dispatcher.Extension = (*haTestFlowHandle)(nil)
func (*haTestFlowHandle) OnTick(_ context.Context, _ *proto.Task) {
type haTestDispatcherExt struct {
cnt int
}
func (*haTestFlowHandle) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.State == proto.TaskStatePending {
gTask.Step = proto.StepOne
var _ dispatcher.Extension = (*haTestDispatcherExt)(nil)
func (*haTestDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
func (dsp *haTestDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 10
return [][]byte{
[]byte("task1"),
[]byte("task2"),
@ -52,7 +54,7 @@ func (*haTestFlowHandle) OnNextStage(_ context.Context, _ dispatcher.TaskHandle,
}, nil
}
if gTask.Step == proto.StepOne {
gTask.Step = proto.StepTwo
dsp.cnt = 15
return [][]byte{
[]byte("task11"),
[]byte("task12"),
@ -64,23 +66,37 @@ func (*haTestFlowHandle) OnNextStage(_ context.Context, _ dispatcher.TaskHandle,
return nil, nil
}
func (*haTestFlowHandle) OnErrStage(ctx context.Context, h dispatcher.TaskHandle, gTask *proto.Task, receiveErr []error) (subtaskMeta []byte, err error) {
func (*haTestDispatcherExt) OnErrStage(ctx context.Context, h dispatcher.TaskHandle, gTask *proto.Task, receiveErr []error) (subtaskMeta []byte, err error) {
return nil, nil
}
func (*haTestFlowHandle) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
func (*haTestDispatcherExt) GetEligibleInstances(_ context.Context, _ *proto.Task) ([]*infosync.ServerInfo, error) {
return generateSchedulerNodes4Test()
}
func (*haTestFlowHandle) IsRetryableErr(error) bool {
func (*haTestDispatcherExt) IsRetryableErr(error) bool {
return true
}
func (dsp *haTestDispatcherExt) StageFinished(task *proto.Task) bool {
if task.Step == proto.StepInit && dsp.cnt >= 10 {
return true
}
if task.Step == proto.StepOne && dsp.cnt >= 15 {
return true
}
return false
}
func (dsp *haTestDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == proto.StepOne && dsp.cnt >= 15
}
func TestHABasic(t *testing.T) {
var m sync.Map
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{})
RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 4)
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()"))
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "4*return()"))
@ -94,10 +110,9 @@ func TestHABasic(t *testing.T) {
func TestHAManyNodes(t *testing.T) {
var m sync.Map
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{})
RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 30)
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler", "return()"))
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager", "30*return()"))
@ -111,10 +126,9 @@ func TestHAManyNodes(t *testing.T) {
func TestHAFailInDifferentStage(t *testing.T) {
var m sync.Map
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{})
RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 6)
// stage1 : server num from 6 to 3.
// stage2 : server num from 3 to 2.
@ -136,7 +150,7 @@ func TestHAFailInDifferentStageManyNodes(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{})
RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 30)
// stage1 : server num from 30 to 27.
// stage2 : server num from 27 to 26.
@ -158,7 +172,7 @@ func TestHAReplacedButRunning(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{})
RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 4)
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume", "10*return(true)"))
DispatchTaskAndCheckSuccess("😊", t, &m)
@ -171,7 +185,7 @@ func TestHAReplacedButRunningManyNodes(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
RegisterTaskMeta(t, ctrl, &m, &haTestFlowHandle{})
RegisterTaskMeta(t, ctrl, &m, &haTestDispatcherExt{})
distContext := testkit.NewDistExecutionContext(t, 30)
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBPartitionThenResume", "30*return(true)"))
DispatchTaskAndCheckSuccess("😊", t, &m)

View File

@ -30,7 +30,9 @@ import (
"go.uber.org/mock/gomock"
)
type rollbackDispatcherExt struct{}
type rollbackDispatcherExt struct {
cnt int
}
var _ dispatcher.Extension = (*rollbackDispatcherExt)(nil)
var rollbackCnt atomic.Int32
@ -38,9 +40,9 @@ var rollbackCnt atomic.Int32
func (*rollbackDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
func (*rollbackDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.State == proto.TaskStatePending {
gTask.Step = proto.StepOne
func (dsp *rollbackDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 3
return [][]byte{
[]byte("task1"),
[]byte("task2"),
@ -62,6 +64,14 @@ func (*rollbackDispatcherExt) IsRetryableErr(error) bool {
return true
}
func (dsp *rollbackDispatcherExt) StageFinished(task *proto.Task) bool {
return task.Step == proto.StepInit && dsp.cnt >= 3
}
func (dsp *rollbackDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == proto.StepInit && dsp.cnt >= 3
}
type testRollbackMiniTask struct{}
func (testRollbackMiniTask) IsMinimalTask() {}
@ -125,7 +135,7 @@ func TestFrameworkRollback(t *testing.T) {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/dispatcher/cancelTaskAfterRefreshTask"))
}()
DispatchTaskAndCheckState("key2", t, &m, proto.TaskStateReverted)
DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateReverted)
require.Equal(t, int32(2), rollbackCnt.Load())
rollbackCnt.Store(0)
distContext.Close()

View File

@ -35,16 +35,18 @@ import (
"go.uber.org/mock/gomock"
)
type testDispatcherExt struct{}
type testDispatcherExt struct {
cnt int
}
var _ dispatcher.Extension = (*testDispatcherExt)(nil)
func (*testDispatcherExt) OnTick(_ context.Context, _ *proto.Task) {
}
func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.State == proto.TaskStatePending {
gTask.Step = proto.StepOne
func (dsp *testDispatcherExt) OnNextSubtasksBatch(_ context.Context, _ dispatcher.TaskHandle, gTask *proto.Task) (metas [][]byte, err error) {
if gTask.Step == proto.StepInit {
dsp.cnt = 3
return [][]byte{
[]byte("task1"),
[]byte("task2"),
@ -52,7 +54,7 @@ func (*testDispatcherExt) OnNextStage(_ context.Context, _ dispatcher.TaskHandle
}, nil
}
if gTask.Step == proto.StepOne {
gTask.Step = proto.StepTwo
dsp.cnt = 4
return [][]byte{
[]byte("task4"),
}, nil
@ -64,6 +66,20 @@ func (*testDispatcherExt) OnErrStage(_ context.Context, _ dispatcher.TaskHandle,
return nil, nil
}
func (dsp *testDispatcherExt) StageFinished(task *proto.Task) bool {
if task.Step == proto.StepInit && dsp.cnt >= 3 {
return true
}
if task.Step == proto.StepOne && dsp.cnt >= 4 {
return true
}
return false
}
func (dsp *testDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == proto.StepOne && dsp.cnt >= 4
}
func generateSchedulerNodes4Test() ([]*infosync.ServerInfo, error) {
serverInfos := infosync.MockGlobalServerInfoManagerEntry.GetAllServerInfo()
if len(serverInfos) == 0 {
@ -174,9 +190,9 @@ func RegisterTaskMeta(t *testing.T, ctrl *gomock.Controller, m *sync.Map, dispat
mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) {
switch step {
case proto.StepOne:
case proto.StepInit:
return &testSubtaskExecutor{m: m}, nil
case proto.StepTwo:
case proto.StepOne:
return &testSubtaskExecutor1{m: m}, nil
}
panic("invalid step")
@ -212,9 +228,9 @@ func RegisterTaskMetaForExample2(t *testing.T, ctrl *gomock.Controller, m *sync.
mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) {
switch step {
case proto.StepOne:
case proto.StepInit:
return &testSubtaskExecutor2{m: m}, nil
case proto.StepTwo:
case proto.StepOne:
return &testSubtaskExecutor3{m: m}, nil
}
panic("invalid step")
@ -223,6 +239,10 @@ func RegisterTaskMetaForExample2(t *testing.T, ctrl *gomock.Controller, m *sync.
}
func RegisterTaskMetaForExample2Inner(t *testing.T, mockExtension scheduler.Extension, dispatcherHandle dispatcher.Extension) {
t.Cleanup(func() {
dispatcher.ClearDispatcherFactory()
scheduler.ClearSchedulers()
})
dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample2,
func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher {
baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task)
@ -244,9 +264,9 @@ func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync.
mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(minimalTask proto.MinimalTask, tp string, step int64) (execute.MiniTaskExecutor, error) {
switch step {
case proto.StepOne:
case proto.StepInit:
return &testSubtaskExecutor4{m: m}, nil
case proto.StepTwo:
case proto.StepOne:
return &testSubtaskExecutor5{m: m}, nil
}
panic("invalid step")
@ -255,6 +275,10 @@ func RegisterTaskMetaForExample3(t *testing.T, ctrl *gomock.Controller, m *sync.
}
func RegisterTaskMetaForExample3Inner(t *testing.T, mockExtension scheduler.Extension, dispatcherHandle dispatcher.Extension) {
t.Cleanup(func() {
dispatcher.ClearDispatcherFactory()
scheduler.ClearSchedulers()
})
dispatcher.RegisterDispatcherFactory(proto.TaskTypeExample3,
func(ctx context.Context, taskMgr *storage.TaskManager, serverID string, task *proto.Task) dispatcher.Dispatcher {
baseDispatcher := dispatcher.NewBaseDispatcher(ctx, taskMgr, serverID, task)
@ -500,6 +524,7 @@ func TestFrameworkSubTaskFailed(t *testing.T) {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/MockExecutorRunErr"))
}()
DispatchTaskAndCheckState("key1", t, &m, proto.TaskStateReverted)
distContext.Close()
}
@ -563,7 +588,6 @@ func TestSchedulerDownBasic(t *testing.T) {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager"))
distContext.Close()
}
@ -582,7 +606,6 @@ func TestSchedulerDownManyNodes(t *testing.T) {
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockCleanScheduler"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockTiDBDown"))
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/disttask/framework/scheduler/mockStopManager"))
distContext.Close()
}

View File

@ -48,7 +48,7 @@ const (
// TaskStep is the step of task.
const (
StepInit int64 = -1
StepInit int64 = 0
StepOne int64 = 1
StepTwo int64 = 2
)
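Moving StepInit from -1 to 0 matters because subtasks are now inserted with the task's current step (see AddSubTasks), so the first stage needs a real step value, and executors select their handler by that same value, as in backfill's GetSubtaskExecutor above. A small sketch of the resulting mapping; the stage names are illustrative:

```go
package main

import "fmt"

// New numbering: the first stage runs at StepInit (0) and the framework
// increments the step as each stage finishes.
const (
	StepInit int64 = 0
	StepOne  int64 = 1
)

// executorForStep mirrors the shape of backfill's GetSubtaskExecutor switch.
func executorForStep(step int64) (string, error) {
	switch step {
	case StepInit:
		return "read-and-ingest", nil
	case StepOne:
		return "merge-sort", nil
	default:
		return "", fmt.Errorf("unknown backfill step %d", step)
	}
}

func main() {
	for _, step := range []int64{StepInit, StepOne, 2} {
		name, err := executorForStep(step)
		fmt.Println(step, name, err)
	}
}
```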
@ -56,7 +56,7 @@ const (
// TaskIDLabelName is the label name of task id.
const TaskIDLabelName = "task_id"
// Task represents the task of distribute framework.
// Task represents the task of distributed framework.
type Task struct {
ID int64
Key string

View File

@ -301,8 +301,8 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str
v, ok := testContexts.Load(m.id)
if ok {
<-v.(*TestContext).TestSyncSubtaskRun
m.Stop()
_ = infosync.MockGlobalServerInfoManagerEntry.DeleteByID(m.id)
m.Stop()
}
}()
})

View File

@ -84,7 +84,6 @@ func (s *BaseScheduler) startCancelCheck(ctx context.Context, wg *sync.WaitGroup
return
case <-ticker.C:
canceled, err := s.taskTable.IsSchedulerCanceled(s.taskID, s.id)
logutil.Logger(s.logCtx).Info("scheduler before canceled")
if err != nil {
continue
}
@ -170,15 +169,25 @@ func (s *BaseScheduler) run(ctx context.Context, task *proto.Task) error {
subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, task.Step, proto.TaskStatePending)
if err != nil {
s.onError(err)
break
logutil.Logger(s.logCtx).Warn("GetSubtaskInStates meets error", zap.Error(err))
continue
}
if subtask == nil {
break
newTask, err := s.taskTable.GetGlobalTaskByID(task.ID)
if err != nil {
logutil.Logger(s.logCtx).Warn("GetGlobalTaskByID meets error", zap.Error(err))
continue
}
// When the task moves to the next step or the task state changes, the scheduler should exit.
if newTask.Step != task.Step || newTask.State != task.State {
break
}
continue
}
s.startSubtask(subtask.ID)
if err := s.getError(); err != nil {
break
logutil.Logger(s.logCtx).Warn("startSubtask meets error", zap.Error(err))
continue
}
failpoint.Inject("mockCleanScheduler", func() {
v, ok := testContexts.Load(s.id)
@ -216,11 +225,13 @@ func (s *BaseScheduler) runSubtask(ctx context.Context, scheduler execute.Subtas
zap.Int64("subtask_step", subtask.Step))
failpoint.Inject("mockTiDBDown", func(val failpoint.Value) {
logutil.Logger(s.logCtx).Info("trigger mockTiDBDown")
if s.id == val.(string) || s.id == ":4001" || s.id == ":4002" {
v, ok := testContexts.Load(s.id)
if ok {
v.(*TestContext).TestSyncSubtaskRun <- struct{}{}
v.(*TestContext).mockDown.Store(true)
logutil.Logger(s.logCtx).Info("mockTiDBDown")
time.Sleep(2 * time.Second)
failpoint.Return()
}
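With batches arriving incrementally, an idle scheduler can no longer treat "no pending subtask" as "stage over": the next batch may simply not have landed yet. It therefore re-reads the global task and exits only once the task has moved on, a rule sketched here with simplified types:

```go
package main

import "fmt"

// taskView holds the two fields the scheduler compares after re-reading the
// global task (simplified; see BaseScheduler.run above).
type taskView struct {
	step  int64
	state string
}

// shouldExit reports whether the scheduler may stop polling: only when the
// task has moved to another step or its state has changed.
func shouldExit(current, latest taskView) bool {
	return latest.step != current.step || latest.state != current.state
}

func main() {
	cur := taskView{step: 0, state: "running"}
	fmt.Println(shouldExit(cur, taskView{step: 0, state: "running"})) // false: keep polling
	fmt.Println(shouldExit(cur, taskView{step: 1, state: "running"})) // true: stage advanced
}
```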

View File

@ -108,28 +108,9 @@ func TestSchedulerRun(t *testing.T) {
err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp})
require.EqualError(t, err, poolErr.Error())
// 4. get subtask failed
getSubtaskErr := errors.New("get subtask error")
cleanupErr := errors.New("clean up error")
var taskID int64 = 1
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn)
mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(nil, getSubtaskErr)
mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(cleanupErr)
err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID})
require.EqualError(t, err, getSubtaskErr.Error())
// 5. update subtask state failed
updateSubtaskErr := errors.New("update subtask error")
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn)
mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Step: proto.StepOne}, nil)
mockSubtaskTable.EXPECT().StartSubtask(taskID).Return(updateSubtaskErr)
mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID})
require.EqualError(t, err, updateSubtaskErr.Error())
// 6. no subtask executor constructor
// 4. no subtask executor constructor
subtaskExecutorRegisterErr := errors.Errorf("constructor of subtask executor for key not found")
mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil, subtaskExecutorRegisterErr)
var concurrency uint64 = 10
@ -145,7 +126,7 @@ func TestSchedulerRun(t *testing.T) {
mockExtension.EXPECT().GetMiniTaskExecutor(gomock.Any(), gomock.Any(), gomock.Any()).Return(mockMiniTaskExecutor, nil).AnyTimes()
// 7. run subtask failed
// 5. run subtask failed
runSubtaskErr := errors.New("run subtask error")
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn)
@ -154,11 +135,12 @@ func TestSchedulerRun(t *testing.T) {
mockSubtaskExecutor.EXPECT().SplitSubtask(gomock.Any(), gomock.Any()).Return([]proto.MinimalTask{MockMinimalTask{}}, nil)
mockMiniTaskExecutor.EXPECT().Run(gomock.Any()).Return(runSubtaskErr)
mockSubtaskTable.EXPECT().UpdateSubtaskStateAndError(taskID, proto.TaskStateFailed, gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetGlobalTaskByID(gomock.Any()).Return(&proto.Task{ID: taskID, Step: proto.StepTwo}, nil)
mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})
require.EqualError(t, err, runSubtaskErr.Error())
// 8. run subtask success
// 6. run subtask success
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn)
mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Type: tp, Step: proto.StepOne}, nil)
@ -168,11 +150,12 @@ func TestSchedulerRun(t *testing.T) {
mockSubtaskExecutor.EXPECT().OnFinished(gomock.Any(), gomock.Any()).Return([]byte(""), nil)
mockSubtaskTable.EXPECT().FinishSubtask(int64(1), gomock.Any()).Return(nil)
mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(nil, nil)
mockSubtaskTable.EXPECT().GetGlobalTaskByID(gomock.Any()).Return(&proto.Task{ID: taskID, Step: proto.StepTwo}, nil)
mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})
require.NoError(t, err)
// 9. run subtask one by one
// 7. run subtask one by one
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn)
mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Type: tp, Step: proto.StepOne}, nil)
@ -193,8 +176,7 @@ func TestSchedulerRun(t *testing.T) {
mockSubtaskExecutor.EXPECT().Cleanup(gomock.Any()).Return(nil)
err = scheduler.Run(runCtx, &proto.Task{Step: proto.StepOne, Type: tp, ID: taskID, Concurrency: concurrency})
require.NoError(t, err)
// 10. cancel
// 8. cancel
mockSubtaskExecutor.EXPECT().Init(gomock.Any()).Return(nil)
mockPool.EXPECT().RunWithConcurrency(gomock.Any(), gomock.Any()).DoAndReturn(runWithConcurrencyFn)
mockSubtaskTable.EXPECT().GetSubtaskInStates("id", taskID, proto.StepOne, []interface{}{proto.TaskStatePending}).Return(&proto.Subtask{ID: 1, Type: tp, Step: proto.StepOne}, nil)
@ -215,8 +197,7 @@ func TestSchedulerRun(t *testing.T) {
time.Sleep(time.Second)
runCancel()
wg.Wait()
// 11. run subtask one by one, on error, we should wait all minimal task finished before call Cleanup
// 9. run subtask one by one; on error, we should wait for all minimal tasks to finish before calling Cleanup
syncCh := make(chan struct{})
lastMinimalTaskFinishTime, cleanupTime := time.Time{}, time.Time{}
require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/disttask/framework/scheduler/waitUntilError", `return(true)`))

View File

@ -350,7 +350,7 @@ func (stm *TaskManager) UpdateErrorToSubtask(tidbID string, taskID int64, err er
}
// PrintSubtaskInfo logs the subtask info by taskKey.
func (stm *TaskManager) PrintSubtaskInfo(taskKey int) {
func (stm *TaskManager) PrintSubtaskInfo(taskKey int64) {
rs, _ := stm.executeSQLWithNewSession(stm.ctx,
"select * from mysql.tidb_background_subtask where task_key = %?", taskKey)
@ -609,6 +609,27 @@ func (stm *TaskManager) UpdateFailedSchedulerIDs(taskID int64, replaceNodes map[
return err
}
// AddSubTasks adds a new batch of subtasks.
func (stm *TaskManager) AddSubTasks(task *proto.Task, subtasks []*proto.Subtask) error {
err := stm.WithNewTxn(stm.ctx, func(se sessionctx.Context) error {
for _, subtask := range subtasks {
subtaskState := proto.TaskStatePending
if task.State == proto.TaskStateReverting {
subtaskState = proto.TaskStateRevertPending
}
_, err := ExecSQL(stm.ctx, se, `insert into mysql.tidb_background_subtask
(step, task_key, exec_id, meta, state, type, checkpoint, summary)
values (%?, %?, %?, %?, %?, %?, %?, %?)`,
task.Step, task.ID, subtask.SchedulerID, subtask.Meta, subtaskState, proto.Type2Int(subtask.Type), []byte{}, "{}")
if err != nil {
return err
}
}
return nil
})
return err
}
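AddSubTasks only inserts subtask rows; unlike UpdateGlobalTaskAndAddSubTasks it leaves the task row untouched, which is what lets the dispatcher append batches mid-stage. The one piece of logic it carries, deriving the initial subtask state from the task state, is sketched here with plain strings:

```go
package main

import "fmt"

// subtaskState sketches the rule AddSubTasks applies when inserting a batch:
// normal batches start pending; batches added while the task is reverting
// start revert_pending so executors take the rollback path.
func subtaskState(taskState string) string {
	if taskState == "reverting" {
		return "revert_pending"
	}
	return "pending"
}

func main() {
	fmt.Println(subtaskState("running"))   // pending
	fmt.Println(subtaskState("reverting")) // revert_pending
}
```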
// UpdateGlobalTaskAndAddSubTasks updates the global task and adds new subtasks
func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtasks []*proto.Subtask, prevState string) (bool, error) {
retryable := true
@ -618,7 +639,6 @@ func (stm *TaskManager) UpdateGlobalTaskAndAddSubTasks(gTask *proto.Task, subtas
if err != nil {
return err
}
if se.GetSessionVars().StmtCtx.AffectedRows() == 0 {
retryable = false
return errors.New("invalid task state transform, state already changed")

View File

@ -191,8 +191,8 @@ func (dsp *ImportDispatcherExt) unregisterTask(ctx context.Context, task *proto.
}
}
// OnNextStage implements dispatcher.Extension interface.
func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Task) (
// OnNextSubtasksBatch generates a batch of the next stage's plan.
func (dsp *ImportDispatcherExt) OnNextSubtasksBatch(ctx context.Context, taskHandle dispatcher.TaskHandle, gTask *proto.Task) (
resSubtaskMeta [][]byte, err error) {
logger := logutil.BgLogger().With(
zap.String("type", gTask.Type),
@ -204,18 +204,18 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch
if err != nil {
return nil, err
}
logger.Info("on next stage")
logger.Info("on next subtasks batch")
defer func() {
// currently, framework will take the task as finished when err is not nil or resSubtaskMeta is empty.
taskFinished := err == nil && len(resSubtaskMeta) == 0
if taskFinished {
// todo: we're not running in a transaction with task update
if err2 := dsp.finishJob(ctx, logger, handle, gTask, taskMeta); err2 != nil {
if err2 := dsp.finishJob(ctx, logger, taskHandle, gTask, taskMeta); err2 != nil {
err = err2
}
} else if err != nil && !dsp.IsRetryableErr(err) {
if err2 := dsp.failJob(ctx, handle, gTask, taskMeta, logger, err.Error()); err2 != nil {
if err2 := dsp.failJob(ctx, taskHandle, gTask, taskMeta, logger, err.Error()); err2 != nil {
// todo: we're not running in a transaction with task update, there might be case
// failJob return error, but task update succeed.
logger.Error("call failJob failed", zap.Error(err2))
@ -223,22 +223,18 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch
}
}()
var nextStep int64
switch gTask.Step {
case proto.StepInit:
case StepImport:
if metrics, ok := metric.GetCommonMetric(ctx); ok {
metrics.BytesCounter.WithLabelValues(metric.StateTotalRestore).Add(float64(taskMeta.Plan.TotalFileSize))
}
if err := preProcess(ctx, handle, gTask, taskMeta, logger); err != nil {
if err := preProcess(ctx, taskHandle, gTask, taskMeta, logger); err != nil {
return nil, err
}
if err = startJob(ctx, logger, handle, taskMeta); err != nil {
if err = startJob(ctx, logger, taskHandle, taskMeta); err != nil {
return nil, err
}
logger.Info("move to import step")
nextStep = StepImport
case StepImport:
case StepPostProcess:
dsp.switchTiKV2NormalMode(ctx, gTask, logger)
failpoint.Inject("clearLastSwitchTime", func() {
dsp.lastSwitchTime.Store(time.Time{})
@ -249,18 +245,20 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch
failpoint.Inject("failWhenDispatchPostProcessSubtask", func() {
failpoint.Return(nil, errors.New("injected error after StepImport"))
})
if err := updateResult(handle, gTask, taskMeta); err != nil {
if err := updateResult(taskHandle, gTask, taskMeta); err != nil {
return nil, err
}
if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil {
return nil, err
}
logger.Info("move to post-process step ", zap.Any("result", taskMeta.Result))
nextStep = StepPostProcess
case StepPostProcess:
case StepPostProcess + 1:
return nil, nil
default:
return nil, errors.Errorf("unknown step %d", gTask.Step)
}
previousSubtaskMetas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step)
previousSubtaskMetas, err := taskHandle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step-1)
if err != nil {
return nil, err
}
@ -273,12 +271,11 @@ func (dsp *ImportDispatcherExt) OnNextStage(ctx context.Context, handle dispatch
if err != nil {
return nil, err
}
metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, nextStep)
metaBytes, err := physicalPlan.ToSubtaskMetas(planCtx, gTask.Step)
if err != nil {
return nil, err
}
gTask.Step = nextStep
logger.Info("generate subtasks", zap.Int64("step", nextStep), zap.Int("subtask-count", len(metaBytes)))
logger.Info("generate subtasks", zap.Int("subtask-count", len(metaBytes)))
return metaBytes, nil
}
@ -340,6 +337,16 @@ func (*ImportDispatcherExt) IsRetryableErr(error) bool {
return false
}
// StageFinished checks whether the current stage is finished.
func (*ImportDispatcherExt) StageFinished(_ *proto.Task) bool {
return true
}
// Finished checks whether the current task is finished.
func (*ImportDispatcherExt) Finished(task *proto.Task) bool {
return task.Step == StepPostProcess+1
}
func (dsp *ImportDispatcherExt) switchTiKV2NormalMode(ctx context.Context, task *proto.Task, logger *zap.Logger) {
dsp.updateCurrentTask(task)
if dsp.disableTiKVImportMode.Load() {
@ -475,6 +482,7 @@ func updateMeta(gTask *proto.Task, taskMeta *TaskMeta) error {
return err
}
gTask.Meta = bs
return nil
}
@ -492,7 +500,7 @@ func toChunkMap(engineCheckpoints map[int32]*checkpoints.EngineCheckpoint) map[i
// we will update taskMeta in place and make gTask.Meta point to the new taskMeta.
func updateResult(handle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error {
metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step)
metas, err := handle.GetPreviousSubtaskMetas(gTask.ID, gTask.Step-1)
if err != nil {
return err
}
@ -564,6 +572,9 @@ func (dsp *ImportDispatcherExt) finishJob(ctx context.Context, logger *zap.Logge
taskHandle dispatcher.TaskHandle, gTask *proto.Task, taskMeta *TaskMeta) error {
dsp.unregisterTask(ctx, gTask)
redactSensitiveInfo(gTask, taskMeta)
if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil {
return err
}
summary := &importer.JobSummary{ImportedRows: taskMeta.Result.LoadedRowCnt}
// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
@ -582,6 +593,9 @@ func (dsp *ImportDispatcherExt) failJob(ctx context.Context, taskHandle dispatch
dsp.switchTiKV2NormalMode(ctx, gTask, logger)
dsp.unregisterTask(ctx, gTask)
redactSensitiveInfo(gTask, taskMeta)
if err := taskHandle.UpdateTask(gTask.State, nil, dispatcher.RetrySQLTimes); err != nil {
return err
}
// retry for 3+6+12+24+(30-4)*30 ~= 825s ~= 14 minutes
backoffer := backoff.NewExponential(dispatcher.RetrySQLInterval, 2, dispatcher.RetrySQLMaxInterval)
return handle.RunWithRetry(ctx, dispatcher.RetrySQLTimes, backoffer, logger,
@ -632,8 +646,6 @@ func rollback(ctx context.Context, handle dispatcher.TaskHandle, gTask *proto.Ta
func stepStr(step int64) string {
switch step {
case proto.StepInit:
return "init"
case StepImport:
return "import"
case StepPostProcess:

View File

@ -89,7 +89,7 @@ func TestDispatcherExt(t *testing.T) {
// to import stage, job should be running
d := dsp.MockDispatcher(task)
ext := importinto.ImportDispatcherExt{}
subtaskMetas, err := ext.OnNextStage(ctx, d, task)
subtaskMetas, err := ext.OnNextSubtasksBatch(ctx, d, task)
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
require.Equal(t, importinto.StepImport, task.Step)
@ -109,7 +109,8 @@ func TestDispatcherExt(t *testing.T) {
require.NoError(t, manager.FinishSubtask(s.ID, []byte("{}")))
}
// to post-process stage, job should be running and in validating step
subtaskMetas, err = ext.OnNextStage(ctx, d, task)
task.Step++
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task)
require.NoError(t, err)
require.Len(t, subtaskMetas, 1)
require.Equal(t, importinto.StepPostProcess, task.Step)
@ -118,7 +119,8 @@ func TestDispatcherExt(t *testing.T) {
require.Equal(t, "running", gotJobInfo.Status)
require.Equal(t, "validating", gotJobInfo.Step)
// on next stage, job should be finished
subtaskMetas, err = ext.OnNextStage(ctx, d, task)
task.Step++
subtaskMetas, err = ext.OnNextSubtasksBatch(ctx, d, task)
require.NoError(t, err)
require.Len(t, subtaskMetas, 0)
gotJobInfo, err = importer.GetJob(ctx, conn, jobID, "root", true)
@ -133,7 +135,8 @@ func TestDispatcherExt(t *testing.T) {
bs, err = logicalPlan.ToTaskMeta()
require.NoError(t, err)
task.Meta = bs
task.Step = proto.StepInit
// Set step to StepPostProcess to skip the rollback sql.
task.Step = importinto.StepPostProcess
require.NoError(t, importer.StartJob(ctx, conn, jobID))
_, err = ext.OnErrStage(ctx, d, task, []error{errors.New("test")})
require.NoError(t, err)

View File

@ -33,10 +33,10 @@ import (
// steps are processed in the following order: StepInit -> StepImport -> StepPostProcess
const (
// StepImport we sort source data and ingest it into TiKV in this step.
StepImport int64 = 1
StepImport int64 = 0
// StepPostProcess we verify checksum and add index in this step.
// TODO: Might split into StepValidate and StepAddIndex later.
StepPostProcess int64 = 2
StepPostProcess int64 = 1
)
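Under the renumbering, IMPORT INTO's first stage (StepImport) coincides with proto.StepInit at 0, StepPostProcess becomes 1, and the task counts as finished one step past StepPostProcess, matching the Finished hook above. A quick sketch of the progression:

```go
package main

import "fmt"

// Values from this diff: StepImport now coincides with proto.StepInit (0),
// and the task is finished one step past StepPostProcess.
const (
	StepImport      int64 = 0
	StepPostProcess int64 = 1
)

func finished(step int64) bool { return step == StepPostProcess+1 }

func main() {
	for step := StepImport; step <= StepPostProcess+1; step++ {
		fmt.Printf("step %d finished=%v\n", step, finished(step))
	}
}
```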
// TaskMeta is the task of IMPORT INTO.