Files
tidb/pkg/dxf/framework/storage/task_state.go

296 lines
9.6 KiB
Go

// Copyright 2024 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"encoding/json"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/dxf/framework/proto"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/util/injectfailpoint"
"github.com/pingcap/tidb/pkg/util/sqlexec"
)
// CancelTask moves the task with the given ID to the cancelling state.
// Only a task that is currently pending, running or awaiting resolution is
// updated; for a task in any other state the update matches no rows.
func (mgr *TaskManager) CancelTask(ctx context.Context, taskID int64) error {
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return injected
	}
	const cancelSQL = `update mysql.tidb_global_task
set state = %?,
state_update_time = CURRENT_TIMESTAMP()
where id = %? and state in (%?, %?, %?)`
	_, execErr := mgr.ExecuteSQLWithNewSession(ctx, cancelSQL,
		proto.TaskStateCancelling, taskID,
		proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateAwaitingResolution,
	)
	return execErr
}
// CancelTaskByKeySession moves the task identified by taskKey to the
// cancelling state, executing the update on the caller-supplied session se
// rather than opening a new one. Only a task that is pending, running or
// awaiting resolution is updated.
func (*TaskManager) CancelTaskByKeySession(ctx context.Context, se sessionctx.Context, taskKey string) error {
	const cancelByKeySQL = `update mysql.tidb_global_task
set state = %?,
state_update_time = CURRENT_TIMESTAMP()
where task_key = %? and state in (%?, %?, %?)`
	executor := se.GetSQLExecutor()
	_, execErr := sqlexec.ExecSQL(ctx, executor, cancelByKeySQL,
		proto.TaskStateCancelling, taskKey,
		proto.TaskStatePending, proto.TaskStateRunning, proto.TaskStateAwaitingResolution,
	)
	return execErr
}
// FailTask implements the scheduler.TaskManager interface.
// It marks the task as failed, stores the serialized taskErr, and sets both
// state_update_time and end_time. The update applies only while the task is
// still in currentState.
func (mgr *TaskManager) FailTask(ctx context.Context, taskID int64, currentState proto.TaskState, taskErr error) error {
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return injected
	}
	const failSQL = `update mysql.tidb_global_task
set state = %?,
error = %?,
state_update_time = CURRENT_TIMESTAMP(),
end_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`
	errPayload := serializeErr(taskErr)
	_, execErr := mgr.ExecuteSQLWithNewSession(ctx, failSQL,
		proto.TaskStateFailed, errPayload, taskID, currentState,
	)
	return execErr
}
// RevertTask implements the scheduler.TaskManager interface.
// It moves the task from taskState to reverting and records taskErr.
func (mgr *TaskManager) RevertTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error {
	injected := injectfailpoint.DXFRandomErrorWithOnePercent()
	if injected != nil {
		return injected
	}
	target := proto.TaskStateReverting
	return mgr.transitTaskStateOnErr(ctx, taskID, taskState, target, taskErr)
}
// transitTaskStateOnErr moves the task from currState to targetState and
// writes the serialized taskErr into the error column. The update is
// conditioned on the task still being in currState; end_time is not touched.
func (mgr *TaskManager) transitTaskStateOnErr(ctx context.Context, taskID int64, currState, targetState proto.TaskState, taskErr error) error {
	const transitSQL = `
update mysql.tidb_global_task
set state = %?,
error = %?,
state_update_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`
	if _, execErr := mgr.ExecuteSQLWithNewSession(ctx, transitSQL,
		targetState, serializeErr(taskErr), taskID, currState); execErr != nil {
		return execErr
	}
	return nil
}
// AwaitingResolveTask implements the scheduler.TaskManager interface.
// It moves the task from taskState to awaiting-resolution and records taskErr.
func (mgr *TaskManager) AwaitingResolveTask(ctx context.Context, taskID int64, taskState proto.TaskState, taskErr error) error {
	injected := injectfailpoint.DXFRandomErrorWithOnePercent()
	if injected != nil {
		return injected
	}
	target := proto.TaskStateAwaitingResolution
	return mgr.transitTaskStateOnErr(ctx, taskID, taskState, target, taskErr)
}
// RevertedTask implements the scheduler.TaskManager interface.
// It moves the task from reverting to reverted and sets both
// state_update_time and end_time. A task not in the reverting state is left
// unchanged.
func (mgr *TaskManager) RevertedTask(ctx context.Context, taskID int64) error {
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return injected
	}
	const revertedSQL = `update mysql.tidb_global_task
set state = %?,
state_update_time = CURRENT_TIMESTAMP(),
end_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`
	_, execErr := mgr.ExecuteSQLWithNewSession(ctx, revertedSQL,
		proto.TaskStateReverted, taskID, proto.TaskStateReverting,
	)
	return execErr
}
// PauseTask requests a pause of the task identified by taskKey by moving it
// from pending or running to pausing. The returned bool reports whether the
// update affected any row, i.e. whether the task was in a pausable state.
func (mgr *TaskManager) PauseTask(ctx context.Context, taskKey string) (bool, error) {
	var updated bool
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return false, injected
	}
	const pauseSQL = `update mysql.tidb_global_task
set state = %?,
state_update_time = CURRENT_TIMESTAMP()
where task_key = %? and state in (%?, %?)`
	sessErr := mgr.WithNewSession(func(se sessionctx.Context) error {
		if _, execErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), pauseSQL,
			proto.TaskStatePausing, taskKey, proto.TaskStatePending, proto.TaskStateRunning,
		); execErr != nil {
			return execErr
		}
		if se.GetSessionVars().StmtCtx.AffectedRows() != 0 {
			updated = true
		}
		return nil
	})
	return updated, sessErr
}
// PausedTask moves the task from pausing to paused. A task that is not in the
// pausing state is left unchanged.
func (mgr *TaskManager) PausedTask(ctx context.Context, taskID int64) error {
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return injected
	}
	const pausedSQL = `update mysql.tidb_global_task
set state = %?,
state_update_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`
	_, execErr := mgr.ExecuteSQLWithNewSession(ctx, pausedSQL,
		proto.TaskStatePaused, taskID, proto.TaskStatePausing,
	)
	return execErr
}
// ResumeTask requests that the paused task identified by taskKey be resumed
// by moving it from paused to resuming. The returned bool reports whether the
// update affected any row, i.e. whether the task was actually paused.
func (mgr *TaskManager) ResumeTask(ctx context.Context, taskKey string) (bool, error) {
	var updated bool
	const resumeSQL = `update mysql.tidb_global_task
set state = %?,
state_update_time = CURRENT_TIMESTAMP()
where task_key = %? and state = %?`
	sessErr := mgr.WithNewSession(func(se sessionctx.Context) error {
		if _, execErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), resumeSQL,
			proto.TaskStateResuming, taskKey, proto.TaskStatePaused,
		); execErr != nil {
			return execErr
		}
		if se.GetSessionVars().StmtCtx.AffectedRows() != 0 {
			updated = true
		}
		return nil
	})
	return updated, sessErr
}
// ResumedTask implements the scheduler.TaskManager interface.
// It moves the task from resuming back to running. A task that is not in the
// resuming state is left unchanged.
func (mgr *TaskManager) ResumedTask(ctx context.Context, taskID int64) error {
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return injected
	}
	const resumedSQL = `
update mysql.tidb_global_task
set state = %?,
state_update_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`
	_, execErr := mgr.ExecuteSQLWithNewSession(ctx, resumedSQL,
		proto.TaskStateRunning, taskID, proto.TaskStateResuming,
	)
	return execErr
}
// ModifyTaskByID moves the task with the given ID into the modifying state
// and stores param, JSON-encoded, in the modify_params column. The read and
// the conditional update run inside a single transaction.
//
// It returns ErrTaskStateNotAllow when param.PrevState cannot move to
// modifying, and ErrTaskChanged when the task's state no longer equals
// param.PrevState (either on read or because the conditional update matched
// no rows). JSON encoding errors are returned traced.
func (mgr *TaskManager) ModifyTaskByID(ctx context.Context, taskID int64, param *proto.ModifyParam) error {
	if !param.PrevState.CanMoveToModifying() {
		return ErrTaskStateNotAllow
	}
	paramJSON, marshalErr := json.Marshal(param)
	if marshalErr != nil {
		return errors.Trace(marshalErr)
	}
	return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
		// All errors are closure-local: the callback must not write to
		// variables of the enclosing function.
		task, getErr := mgr.getTaskBaseByID(ctx, se.GetSQLExecutor(), taskID)
		if getErr != nil {
			return getErr
		}
		if task.State != param.PrevState {
			return ErrTaskChanged
		}
		failpoint.InjectCall("beforeMoveToModifying")
		if _, execErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), `
update mysql.tidb_global_task
set state = %?, modify_params = %?, state_update_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`,
			proto.TaskStateModifying, json.RawMessage(paramJSON), taskID, param.PrevState,
		); execErr != nil {
			return execErr
		}
		if se.GetSessionVars().StmtCtx.AffectedRows() == 0 {
			// The txn is pessimistic: another txn may have changed the task
			// state before this one commits without producing a
			// write-conflict, so the guarded update can match no rows.
			return ErrTaskChanged
		}
		return nil
	})
}
// ModifiedTask implements the scheduler.TaskManager interface.
// In one transaction it moves the task out of the modifying state back to
// task.ModifyParam.PrevState, writes task.RequiredSlots (as concurrency),
// task.MaxNodeCount and task.Meta, and clears modify_params. It then copies
// the new concurrency onto the task's subtasks that are pending, running or
// paused. If the task is no longer modifying, nothing else is done.
func (mgr *TaskManager) ModifiedTask(ctx context.Context, task *proto.Task) error {
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return injected
	}
	restoreState := task.ModifyParam.PrevState
	const taskSQL = `
update mysql.tidb_global_task
set state = %?,
concurrency = %?,
max_node_count = %?,
meta = %?,
modify_params = null,
state_update_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`
	const subtaskSQL = `
update mysql.tidb_background_subtask
set concurrency = %?, state_update_time = unix_timestamp()
where task_key = %? and state in (%?, %?, %?)`
	return mgr.WithNewTxn(ctx, func(se sessionctx.Context) error {
		failpoint.InjectCall("beforeModifiedTask")
		if _, execErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), taskSQL,
			restoreState, task.RequiredSlots, task.MaxNodeCount, task.Meta, task.ID, proto.TaskStateModifying,
		); execErr != nil {
			return execErr
		}
		if se.GetSessionVars().StmtCtx.AffectedRows() == 0 {
			// Presumably already handled by another owner node; skip.
			return nil
		}
		// Subtasks in a final state keep their concurrency. Subtasks may later
		// carry their own concurrency (see TaskExecInfo), which is not handled
		// here yet.
		_, execErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), subtaskSQL,
			task.RequiredSlots, task.ID,
			proto.SubtaskStatePending, proto.SubtaskStateRunning, proto.SubtaskStatePaused,
		)
		return execErr
	})
}
// SucceedTask moves the task from running to succeed, sets its step to
// proto.StepDone, and stamps state_update_time and end_time. A task that is
// not running is left unchanged.
func (mgr *TaskManager) SucceedTask(ctx context.Context, taskID int64) error {
	if injected := injectfailpoint.DXFRandomErrorWithOnePercent(); injected != nil {
		return injected
	}
	const succeedSQL = `
update mysql.tidb_global_task
set state = %?,
step = %?,
state_update_time = CURRENT_TIMESTAMP(),
end_time = CURRENT_TIMESTAMP()
where id = %? and state = %?`
	return mgr.WithNewSession(func(se sessionctx.Context) error {
		if _, execErr := sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), succeedSQL,
			proto.TaskStateSucceed, proto.StepDone, taskID, proto.TaskStateRunning,
		); execErr != nil {
			return execErr
		}
		return nil
	})
}