// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package workerpool

import (
	"context"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pingcap/errors"
	"github.com/pingcap/failpoint"
	"github.com/pingcap/tidb/pkg/resourcemanager/util"
	tidbutil "github.com/pingcap/tidb/pkg/util"
	"github.com/pingcap/tidb/pkg/util/logutil"
	"github.com/pingcap/tidb/pkg/util/syncutil"
	atomicutil "go.uber.org/atomic"
	"go.uber.org/zap"
)

// Context is the context used for the worker pool.
type Context struct {
	context.Context
	cancel   context.CancelFunc
	firstErr atomic.Pointer[error]
}

// OnError stores the error and cancels the context.
// If an error is already set, it will not be overwritten.
func (ctx *Context) OnError(err error) {
	ctx.firstErr.CompareAndSwap(nil, &err)
	logutil.BgLogger().Error("worker pool encountered error", zap.Error(err))
	ctx.cancel()
}

// OperatorErr returns the error caused by business logic.
func (ctx *Context) OperatorErr() error {
	err := ctx.firstErr.Load()
	if err == nil {
		return nil
	}
	return *err
}
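
// firstErrExample is a hypothetical sketch (not part of this file's API) of
// the first-error semantics above: only the first OnError call is recorded;
// later errors are still logged and also cancel the context, but OperatorErr
// keeps returning the first one.
func firstErrExample() error {
	wctx := NewContext(context.Background())
	wctx.OnError(errors.New("first"))
	wctx.OnError(errors.New("second")) // logged, but does not overwrite "first"
	return wctx.OperatorErr()          // returns "first"
}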

// Cancel cancels the context of the operator.
func (ctx *Context) Cancel() {
	ctx.cancel()
}

// NewContext creates a new Context.
func NewContext(ctx context.Context) *Context {
	cctx, cancel := context.WithCancel(ctx)
	return &Context{
		Context: cctx,
		cancel:  cancel,
	}
}

// TaskMayPanic is a type to remind the developer that they need to handle
// panics in the task.
type TaskMayPanic interface {
	// RecoverArgs returns the arguments for the pkg/util.Recover function of
	// this task. The returned error is passed to the upper level; if it is not
	// provided, a default error is used.
	RecoverArgs() (metricsLabel string, funcInfo string, err error)
}

// Worker is the worker interface.
type Worker[T TaskMayPanic, R any] interface {
	// HandleTask consumes a task (T) and either produces a result (R) or
	// returns an error. Results are sent to the result channel by calling the
	// `send` function; a returned error is caught, logged, and broadcast to
	// the other operators by the worker pool.
	// TODO(joechenrh): we can pass the context to HandleTask, so we don't need to
	// store the context in each worker implementation.
	HandleTask(task T, send func(R)) error
	Close() error
}
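
// The types below (exampleTask, doubleWorker) are a hypothetical, minimal
// sketch of implementing the two interfaces above; they are illustrative only
// and not part of the pool's API.
type exampleTask struct {
	id int
}

// RecoverArgs implements TaskMayPanic. Returning a nil error makes the pool
// fall back to its default "task panic" error when a panic is recovered.
func (exampleTask) RecoverArgs() (string, string, error) {
	return "example", "doubleWorker.HandleTask", nil
}

// doubleWorker doubles the task id and emits it as the result.
type doubleWorker struct{}

// HandleTask implements Worker. send blocks until the result is consumed or
// the pool's context is done.
func (doubleWorker) HandleTask(t exampleTask, send func(int)) error {
	send(t.id * 2)
	return nil
}

// Close implements Worker.
func (doubleWorker) Close() error { return nil }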

type tuneConfig struct {
	wg *sync.WaitGroup
}

// Tuner is an interface that provides the ability to tune a worker pool. It's
// used to pass a worker pool around without the import cycle caused by the
// generic type.
type Tuner interface {
	Tune(numWorkers int32, wait bool)
}

// WorkerPool is a pool of workers.
type WorkerPool[T TaskMayPanic, R any] struct {
	// wctx is the context used for the whole pipeline; ctx and cancel are
	// derived from wctx to notify all workers to quit when the tasks are done
	// or any error occurs.
	wctx          *Context
	ctx           context.Context
	cancel        context.CancelFunc
	name          string
	numWorkers    int32
	originWorkers int32
	runningTask   atomicutil.Int32
	taskChan      chan T
	resChan       chan R
	quitChan      chan tuneConfig
	wg            tidbutil.WaitGroupWrapper
	createWorker  func() Worker[T, R]
	lastTuneTs    atomicutil.Time
	started       atomic.Bool
	mu            syncutil.RWMutex
}

// Option is the config option for WorkerPool.
type Option[T TaskMayPanic, R any] interface {
	Apply(pool *WorkerPool[T, R])
}

// None is a type placeholder for a worker pool that does not have a result
// receiver.
type None struct{}
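
// noResultWorker is a hypothetical sketch of a worker parameterized with None:
// such a pool produces no results, so Start (below) skips creating the result
// channel. It reuses the illustrative exampleTask type defined above.
type noResultWorker struct{}

// HandleTask implements Worker. The send callback is ignored because there is
// nothing to emit.
func (noResultWorker) HandleTask(_ exampleTask, _ func(None)) error {
	return nil
}

// Close implements Worker.
func (noResultWorker) Close() error { return nil }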

// NewWorkerPool creates a new worker pool.
func NewWorkerPool[T TaskMayPanic, R any](
	name string,
	_ util.Component,
	numWorkers int,
	createWorker func() Worker[T, R],
	opts ...Option[T, R],
) *WorkerPool[T, R] {
	if numWorkers <= 0 {
		numWorkers = 1
	}
	failpoint.InjectCall("NewWorkerPool", numWorkers)
	p := &WorkerPool[T, R]{
		name:          name,
		numWorkers:    int32(numWorkers),
		originWorkers: int32(numWorkers),
		quitChan:      make(chan tuneConfig),
	}
	for _, opt := range opts {
		opt.Apply(p)
	}
	p.createWorker = createWorker
	return p
}
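
// newExamplePool is a hypothetical construction sketch using the illustrative
// types above. util.UNKNOWN is assumed to be an available Component value; the
// component argument is currently ignored by the constructor anyway.
func newExamplePool() *WorkerPool[exampleTask, int] {
	return NewWorkerPool("example", util.UNKNOWN, 4, func() Worker[exampleTask, int] {
		return doubleWorker{}
	})
}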

// SetTaskReceiver sets the task receiver for the pool.
func (p *WorkerPool[T, R]) SetTaskReceiver(recv chan T) {
	p.taskChan = recv
}

// SetResultSender sets the result sender for the pool.
func (p *WorkerPool[T, R]) SetResultSender(sender chan R) {
	p.resChan = sender
}

// Start starts the default number of workers.
func (p *WorkerPool[T, R]) Start(ctx *Context) {
	if p.taskChan == nil {
		p.taskChan = make(chan T)
	}
	if p.resChan == nil {
		var zero R
		var r any = zero
		if _, ok := r.(None); !ok {
			p.resChan = make(chan R)
		}
	}
	p.wctx = ctx
	p.ctx, p.cancel = context.WithCancel(ctx)
	p.mu.Lock()
	defer p.mu.Unlock()
	for range p.numWorkers {
		p.runAWorker()
	}
	p.started.Store(true)
}
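
// runExample is a hypothetical end-to-end sketch of the pool lifecycle: start
// the pool, feed tasks through a caller-owned channel, drain results, then
// release. Error handling on the producing side is elided for brevity.
func runExample(ctx context.Context) (sum int, err error) {
	pool := newExamplePool()
	taskCh := make(chan exampleTask)
	pool.SetTaskReceiver(taskCh)
	wctx := NewContext(ctx)
	pool.Start(wctx)

	// Drain results concurrently; Release closes the result channel, which
	// ends this loop.
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		for r := range pool.GetResultChan() {
			sum += r
		}
	}()

	for i := range 3 {
		taskCh <- exampleTask{id: i}
	}
	// Closing the task channel tells the workers to close and exit.
	close(taskCh)
	pool.Release()
	wg.Wait()
	return sum, wctx.OperatorErr()
}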

func (p *WorkerPool[T, R]) handleTaskWithRecover(w Worker[T, R], task T) {
	p.runningTask.Add(1)
	defer func() {
		p.runningTask.Add(-1)
	}()
	label, funcInfo, err := task.RecoverArgs()
	recoverFn := func() {
		if err != nil {
			p.wctx.OnError(err)
		} else {
			p.wctx.OnError(errors.Errorf("task panic: %s, func info: %s", label, funcInfo))
		}
	}
	defer tidbutil.Recover(label, funcInfo, recoverFn, false)
	sendResult := func(r R) {
		if p.resChan == nil {
			return
		}
		select {
		case p.resChan <- r:
		case <-p.ctx.Done():
		}
	}
	if err := w.HandleTask(task, sendResult); err != nil {
		p.wctx.OnError(err)
	}
}

func (p *WorkerPool[T, R]) runAWorker() {
	w := p.createWorker()
	if w == nil {
		return // Failed to create a worker, quit.
	}
	p.wg.Run(func() {
		var err error
		defer func() {
			if err != nil {
				p.wctx.OnError(err)
			}
		}()
		for {
			select {
			case task, ok := <-p.taskChan:
				if !ok {
					err = w.Close()
					return
				}
				p.handleTaskWithRecover(w, task)
			case cfg, ok := <-p.quitChan:
				err = w.Close()
				if ok {
					cfg.wg.Done()
				}
				return
			case <-p.ctx.Done():
				err = w.Close()
				return
			}
		}
	})
}

// AddTask adds a task to the pool; only used in tests.
func (p *WorkerPool[T, R]) AddTask(task T) {
	select {
	case <-p.ctx.Done():
	case p.taskChan <- task:
	}
}

// GetResultChan gets the result channel from the pool.
func (p *WorkerPool[T, R]) GetResultChan() <-chan R {
	return p.resChan
}

// Tune tunes the pool to the specified number of workers.
// wait: whether to wait for the removed workers to close when reducing the
// worker count.
// This method can only be called after Start.
func (p *WorkerPool[T, R]) Tune(numWorkers int32, wait bool) {
	if numWorkers <= 0 {
		numWorkers = 1
	}
	p.lastTuneTs.Store(time.Now())
	p.mu.Lock()
	defer p.mu.Unlock()
	logutil.BgLogger().Info("tune worker pool",
		zap.Int32("from", p.numWorkers), zap.Int32("to", numWorkers))
	// If the pool is not started, just set the number of workers.
	if !p.started.Load() {
		p.numWorkers = numWorkers
		return
	}
	diff := numWorkers - p.numWorkers
	if diff > 0 {
		// Add workers.
		for range diff {
			p.runAWorker()
		}
	} else if diff < 0 {
		// Remove workers.
		var wg sync.WaitGroup
	outer:
		for range -diff {
			wg.Add(1)
			select {
			case p.quitChan <- tuneConfig{wg: &wg}:
			case <-p.ctx.Done():
				logutil.BgLogger().Info("context done when tuning worker pool",
					zap.Int32("from", p.numWorkers), zap.Int32("to", numWorkers))
				wg.Done()
				break outer
			}
		}
		if wait {
			wg.Wait()
		}
	}
	p.numWorkers = numWorkers
}
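
// tuneExample is a hypothetical sketch of tuning a started pool: growing
// returns immediately, while shrinking with wait=true blocks until the surplus
// workers have closed.
func tuneExample(pool *WorkerPool[exampleTask, int]) {
	pool.Tune(8, false) // add workers; returns as soon as they are spawned
	pool.Tune(2, true)  // remove workers; blocks until the six extras have quit
}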

// LastTunerTs returns the last time the pool was tuned.
func (p *WorkerPool[T, R]) LastTunerTs() time.Time {
	return p.lastTuneTs.Load()
}

// Cap returns the capacity of the pool.
func (p *WorkerPool[T, R]) Cap() int32 {
	p.mu.RLock()
	defer p.mu.RUnlock()
	return p.numWorkers
}

// Running returns the number of running workers.
func (p *WorkerPool[T, R]) Running() int32 {
	return p.runningTask.Load()
}

// Name returns the name of the pool.
func (p *WorkerPool[T, R]) Name() string {
	return p.name
}

// CloseAndWait manually closes the pool and waits for completion; only used in
// tests.
func (p *WorkerPool[T, R]) CloseAndWait() {
	if p.cancel != nil {
		p.cancel()
	}
	close(p.quitChan)
	p.Release()
}

// Release waits for the pool to be released. It waits for the input channel to
// be closed, or for the context to be cancelled by a business error.
func (p *WorkerPool[T, R]) Release() {
	// First, wait for all workers to complete.
	p.wg.Wait()
	// Cancel tuning workers.
	if p.cancel != nil {
		p.cancel()
	}
	if p.resChan != nil {
		close(p.resChan)
		p.resChan = nil
	}
}

// GetOriginConcurrency returns the concurrency of the pool at initialization.
func (p *WorkerPool[T, R]) GetOriginConcurrency() int32 {
	return p.originWorkers
}