tidb/executor/aggregate/agg_hash_partial_worker.go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package aggregate

import (
	"sync"
	"time"

	"github.com/pingcap/failpoint"
	"github.com/pingcap/tidb/expression"
	"github.com/pingcap/tidb/sessionctx"
	"github.com/pingcap/tidb/sessionctx/stmtctx"
	"github.com/pingcap/tidb/util/chunk"
	"github.com/twmb/murmur3"
)

// HashAggIntermData indicates the intermediate data of aggregation execution.
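// groupKeys holds the group keys carried in this batch, cursor records the
// next position in groupKeys to be consumed by a final worker, and
// partialResultMap maps each group key to its partial aggregation results.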
type HashAggIntermData struct {
	groupKeys        []string
	cursor           int
	partialResultMap AggPartialResultMapper
}

// HashAggPartialWorker indicates the partial workers of parallel hash agg execution,
// the number of which can be set by `tidb_hashagg_partial_concurrency`.
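// For example, `SET @@tidb_hashagg_partial_concurrency = 4;` asks the current
// session to run hash aggregation with four partial workers (an illustrative
// value; by default the setting follows `tidb_executor_concurrency`).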
type HashAggPartialWorker struct {
	baseHashAggWorker

	inputCh           chan *chunk.Chunk
	outputChs         []chan *HashAggIntermData
	globalOutputCh    chan *AfFinalResult
	giveBackCh        chan<- *HashAggInput
	partialResultsMap AggPartialResultMapper
	groupByItems      []expression.Expression
	groupKey          [][]byte
	// chk stores the input data from child,
	// and is reused by childExec and partial worker.
	chk *chunk.Chunk
}
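
// getChildInput receives one chunk of input. It returns false when the
// executor is shutting down (finishCh is closed) or the input channel is
// closed; otherwise it swaps the received columns into w.chk for processing
// and hands the emptied chunk back through giveBackCh to be refilled.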
func (w *HashAggPartialWorker) getChildInput() bool {
	select {
	case <-w.finishCh:
		return false
	case chk, ok := <-w.inputCh:
		if !ok {
			return false
		}
		w.chk.SwapColumns(chk)
		w.giveBackCh <- &HashAggInput{
			chk:        chk,
			giveBackCh: w.inputCh,
		}
	}
	return true
}
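
// run is the main loop of a partial worker: it repeatedly fetches a chunk
// from the child executor and folds its rows into the worker's partial
// results; once at least one chunk has been processed, the deferred shuffle
// distributes the accumulated intermediate data to the final workers on exit.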
func (w *HashAggPartialWorker) run(ctx sessionctx.Context, waitGroup *sync.WaitGroup, finalConcurrency int) {
	start := time.Now()
	needShuffle, sc := false, ctx.GetSessionVars().StmtCtx
	defer func() {
		if r := recover(); r != nil {
			recoveryHashAgg(w.globalOutputCh, r)
		}
		if needShuffle {
			w.shuffleIntermData(sc, finalConcurrency)
		}
		w.memTracker.Consume(-w.chk.MemoryUsage())
		if w.stats != nil {
			w.stats.WorkerTime += int64(time.Since(start))
		}
		waitGroup.Done()
	}()
	for {
		waitStart := time.Now()
		ok := w.getChildInput()
		if w.stats != nil {
			w.stats.WaitTime += int64(time.Since(waitStart))
		}
		if !ok {
			return
		}
		execStart := time.Now()
		if err := w.updatePartialResult(ctx, sc, w.chk, len(w.partialResultsMap)); err != nil {
			w.globalOutputCh <- &AfFinalResult{err: err}
			return
		}
		if w.stats != nil {
			w.stats.ExecTime += int64(time.Since(execStart))
			w.stats.TaskNum++
		}
		// The intermediate data is guaranteed to be non-empty once execution
		// reaches here, so it is safe to set needShuffle to true.
		needShuffle = true
	}
}
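
// updatePartialResult encodes the group-by key for every row in chk, fetches
// (or lazily creates) the partial results for those groups, and updates each
// aggregate function with the row. Growth of the group keys and of the
// partial results is reported to the memory tracker.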
func (w *HashAggPartialWorker) updatePartialResult(ctx sessionctx.Context, sc *stmtctx.StatementContext, chk *chunk.Chunk, _ int) (err error) {
	memSize := getGroupKeyMemUsage(w.groupKey)
	w.groupKey, err = GetGroupKey(w.ctx, chk, w.groupKey, w.groupByItems)
	failpoint.Inject("ConsumeRandomPanic", nil)
	w.memTracker.Consume(getGroupKeyMemUsage(w.groupKey) - memSize)
	if err != nil {
		return err
	}

	partialResults := w.getPartialResult(sc, w.groupKey, w.partialResultsMap)
	numRows := chk.NumRows()
	rows := make([]chunk.Row, 1)
	allMemDelta := int64(0)
	for i := 0; i < numRows; i++ {
		for j, af := range w.aggFuncs {
			rows[0] = chk.GetRow(i)
			memDelta, err := af.UpdatePartialResult(ctx, rows, partialResults[i][j])
			if err != nil {
				return err
			}
			allMemDelta += memDelta
		}
	}
	w.memTracker.Consume(allMemDelta)
	return nil
}

// shuffleIntermData shuffles the intermediate data of partial workers to the corresponding final workers.
// Only single-machine parallel execution is supported, so the encode and decode steps can be skipped.
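// Each group key is assigned to final worker murmur3(key) % finalConcurrency,
// so all partial results for the same group are sent to the same final worker.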
func (w *HashAggPartialWorker) shuffleIntermData(_ *stmtctx.StatementContext, finalConcurrency int) {
	groupKeysSlice := make([][]string, finalConcurrency)
	for groupKey := range w.partialResultsMap {
		finalWorkerIdx := int(murmur3.Sum32([]byte(groupKey))) % finalConcurrency
		if groupKeysSlice[finalWorkerIdx] == nil {
			groupKeysSlice[finalWorkerIdx] = make([]string, 0, len(w.partialResultsMap)/finalConcurrency)
		}
		groupKeysSlice[finalWorkerIdx] = append(groupKeysSlice[finalWorkerIdx], groupKey)
	}

	for i := range groupKeysSlice {
		if groupKeysSlice[i] == nil {
			continue
		}
		w.outputChs[i] <- &HashAggIntermData{
			groupKeys:        groupKeysSlice[i],
			partialResultMap: w.partialResultsMap,
		}
	}
}