Files
tidb/pkg/executor/aggregate/agg_spill.go
2025-12-10 09:40:20 +00:00

383 lines
11 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package aggregate
import (
"sync"
"sync/atomic"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/executor/aggfuncs"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/disk"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/memory"
"go.uber.org/zap"
)
// spillStatus describes the stage of the parallel hash agg spill procedure.
type spillStatus int32

const (
	// noSpill means no spill has been requested.
	noSpill spillStatus = iota
	// needSpill means a spill has been requested (see setNeedSpill) but has not started yet.
	needSpill
	// inSpilling means data is being spilled to disk right now (see setInSpilling).
	inSpilling
	// spillTriggered is set when a spill round completes; setting it also
	// wakes up goroutines blocked in waitForTheEndOfSpill.
	spillTriggered
)

const (
	// maxSpillTimes indicates how many times the data can spill at most.
	maxSpillTimes = 10
	// spilledPartitionNum is the number of partitions that spilled data is divided into.
	spilledPartitionNum = 256
	// spillLogInfo is the message logged whenever spill mode is triggered.
	spillLogInfo string = "memory exceeds quota, set aggregate mode to spill-mode"
)
// parallelHashAggSpillHelper coordinates spilling data to disk and restoring
// it for the parallel hash aggregation executor.
type parallelHashAggSpillHelper struct {
	// lock groups all mutex-protected state; the embedded Mutex guards
	// every field declared alongside it.
	lock struct {
		*sync.Mutex
		// waitIfInSpilling is broadcast when spilling finishes so that
		// goroutines blocked in waitForTheEndOfSpill can proceed.
		waitIfInSpilling *sync.Cond
		// nextPartitionIdx is the next spilled partition to hand out for
		// restoring; it counts down from spilledPartitionNum-1 to -1.
		nextPartitionIdx int
		// spilledChunksIO[i] holds the disk files spilled for partition i.
		spilledChunksIO [][]*chunk.DataInDiskByChunks
		// status is the current spill stage (noSpill/needSpill/inSpilling/spillTriggered).
		status spillStatus
		// memoryConsumption and memoryQuota snapshot the triggering
		// tracker's stats when a spill is requested; used for logging.
		memoryConsumption int64
		memoryQuota int64
	}
	memTracker *memory.Tracker
	diskTracker *disk.Tracker
	// hasError is read/written atomically as multiple workers may access it.
	hasError atomic.Bool
	// These agg functions are partial agg functions that are same with partial workers'.
	// They only be used for restoring data that are spilled to disk in partial stage.
	aggFuncsForRestoring []aggfuncs.AggFunc
	finalWorkerAggFuncs []aggfuncs.AggFunc
	getNewSpillChunkFunc func() *chunk.Chunk
	spillChunkFieldTypes []*types.FieldType
}
// newSpillHelper builds a parallelHashAggSpillHelper for the parallel hash agg
// executor. aggFuncsForRestoring and finalWorkerAggFuncs must be parallel
// slices (same length), otherwise an error is returned.
func newSpillHelper(
	tracker *memory.Tracker,
	aggFuncsForRestoring []aggfuncs.AggFunc,
	finalWorkerAggFuncs []aggfuncs.AggFunc,
	getNewSpillChunkFunc func() *chunk.Chunk,
	spillChunkFieldTypes []*types.FieldType) (*parallelHashAggSpillHelper, error) {
	if len(aggFuncsForRestoring) != len(finalWorkerAggFuncs) {
		return nil, errors.NewNoStackError("len(aggFuncsForRestoring) != len(finalWorkerAggFuncs)")
	}

	helper := &parallelHashAggSpillHelper{
		memTracker:           tracker,
		aggFuncsForRestoring: aggFuncsForRestoring,
		finalWorkerAggFuncs:  finalWorkerAggFuncs,
		getNewSpillChunkFunc: getNewSpillChunkFunc,
		spillChunkFieldTypes: spillChunkFieldTypes,
	}

	// The condition variable shares the same mutex that guards the lock struct.
	mu := new(sync.Mutex)
	helper.lock.Mutex = mu
	helper.lock.waitIfInSpilling = sync.NewCond(mu)
	helper.lock.spilledChunksIO = make([][]*chunk.DataInDiskByChunks, spilledPartitionNum)
	helper.lock.status = noSpill
	// Partitions are handed out for restoring from the highest index down.
	helper.lock.nextPartitionIdx = spilledPartitionNum - 1
	// memoryConsumption, memoryQuota and hasError start at their zero values.
	return helper, nil
}
// close releases every spilled-chunk disk file registered in the helper.
func (p *parallelHashAggSpillHelper) close() {
	p.lock.Lock()
	defer p.lock.Unlock()
	for _, partitionIOs := range p.lock.spilledChunksIO {
		for _, dataInDisk := range partitionIOs {
			dataInDisk.Close()
		}
	}
}
// isSpilledChunksIOEmpty reports whether no partition has any spilled data on disk.
func (p *parallelHashAggSpillHelper) isSpilledChunksIOEmpty() bool {
	p.lock.Lock()
	defer p.lock.Unlock()
	for _, partitionIOs := range p.lock.spilledChunksIO {
		if len(partitionIOs) != 0 {
			return false
		}
	}
	return true
}
// getNextPartition hands out spilled partition indexes from high to low.
// It returns (-1, false) once every partition has been claimed.
func (p *parallelHashAggSpillHelper) getNextPartition() (int, bool) {
	p.lock.Lock()
	defer p.lock.Unlock()
	idx := p.lock.nextPartitionIdx
	if idx < 0 {
		return -1, false
	}
	p.lock.nextPartitionIdx = idx - 1
	return idx, true
}
// addListInDisks registers one spill round's per-partition disk files:
// dataInDisk[i] is appended to partition i's list.
func (p *parallelHashAggSpillHelper) addListInDisks(dataInDisk []*chunk.DataInDiskByChunks) {
	p.lock.Lock()
	defer p.lock.Unlock()
	for partitionIdx, fileIO := range dataInDisk {
		p.lock.spilledChunksIO[partitionIdx] = append(p.lock.spilledChunksIO[partitionIdx], fileIO)
	}
}
// getListInDisks returns all disk files that were spilled for the given partition.
func (p *parallelHashAggSpillHelper) getListInDisks(partitionNum int) []*chunk.DataInDiskByChunks {
	p.lock.Lock()
	defer p.lock.Unlock()
	partitionIOs := p.lock.spilledChunksIO[partitionNum]
	return partitionIOs
}
// setInSpilling moves the helper into the inSpilling status and logs the
// memory stats recorded when the spill was requested.
func (p *parallelHashAggSpillHelper) setInSpilling() {
	p.lock.Lock()
	defer p.lock.Unlock()
	p.lock.status = inSpilling
	consumed, quota := p.lock.memoryConsumption, p.lock.memoryQuota
	logutil.BgLogger().Info(spillLogInfo, zap.Int64("consumed", consumed), zap.Int64("quota", quota))
}
// isNoSpill reports whether no spill has been requested so far.
func (p *parallelHashAggSpillHelper) isNoSpill() bool {
	p.lock.Lock()
	currentStatus := p.lock.status
	p.lock.Unlock()
	return currentStatus == noSpill
}
// setSpillTriggered marks that a spill round has completed and wakes up all
// goroutines blocked in waitForTheEndOfSpill.
func (p *parallelHashAggSpillHelper) setSpillTriggered() {
	p.lock.Lock()
	defer func() {
		// Broadcast while still holding the lock, then release it.
		p.lock.waitIfInSpilling.Broadcast()
		p.lock.Unlock()
	}()
	p.lock.status = spillTriggered
}
// checkNeedSpill reports whether a spill has been requested but not yet started.
func (p *parallelHashAggSpillHelper) checkNeedSpill() bool {
	p.lock.Lock()
	currentStatus := p.lock.status
	p.lock.Unlock()
	return currentStatus == needSpill
}
// setNeedSpill requests a spill if enough data has been accumulated.
// It snapshots the triggering tracker's consumption and quota for later
// logging. Return true if we successfully set flag.
func (p *parallelHashAggSpillHelper) setNeedSpill(executorTracker *memory.Tracker, triggeredTracker *memory.Tracker) bool {
	p.lock.Lock()
	defer p.lock.Unlock()
	if !hasEnoughDataToSpill(executorTracker, triggeredTracker) {
		return false
	}
	p.lock.status = needSpill
	p.lock.memoryConsumption = triggeredTracker.BytesConsumed()
	p.lock.memoryQuota = triggeredTracker.GetBytesLimit()
	return true
}
// waitForTheEndOfSpill blocks the caller until the helper leaves the
// inSpilling status. The condition is re-checked in a loop after every
// wakeup, as sync.Cond requires (Wait releases the mutex while blocked
// and re-acquires it before returning).
func (p *parallelHashAggSpillHelper) waitForTheEndOfSpill() {
	p.lock.Lock()
	defer p.lock.Unlock()
	for p.lock.status == inSpilling {
		p.lock.waitIfInSpilling.Wait()
	}
}
// We need to check error with atomic as multi partial workers may access it.
func (p *parallelHashAggSpillHelper) checkError() bool {
	return p.hasError.Load()
}
// We need to set error with atomic as multi partial workers may access it.
func (p *parallelHashAggSpillHelper) setError() {
	p.hasError.Store(true)
}
// restoreOnePartition reads back all spilled files of one partition and merges
// them into a fresh in-memory result mapper. It returns the mapper, the total
// restored memory delta, and any error met during restoring. A nil mapper with
// a nil error means there is no partition left to restore.
func (p *parallelHashAggSpillHelper) restoreOnePartition(ctx sessionctx.Context) (aggfuncs.AggPartialResultMapper, int64, error) {
	partitionIdx, ok := p.getNextPartition()
	if !ok {
		// Every partition has already been claimed.
		return nil, 0, nil
	}

	result := aggfuncs.NewAggPartialResultMapper()
	totalRestoredMem := int64(0)
	for _, diskIO := range p.getListInDisks(partitionIdx) {
		memDelta, expandMem, err := p.restoreFromOneSpillFile(ctx, result, diskIO)
		if err != nil {
			return nil, totalRestoredMem, err
		}
		// Track the deserialized data's memory as soon as each file is restored.
		p.memTracker.Consume(memDelta)
		totalRestoredMem += memDelta + expandMem
	}
	return result, totalRestoredMem, nil
}
// processRowContext bundles the per-chunk and per-row state needed by
// processRow while merging rows restored from a spill file.
type processRowContext struct {
	ctx sessionctx.Context
	// chunk is the chunk currently being restored.
	chunk *chunk.Chunk
	// rowPos is the position of the row being processed inside chunk.
	rowPos int
	// keyColPos is the column index of the group key; it is set to the
	// number of agg functions, i.e. the key column follows all agg columns.
	keyColPos int
	aggFuncNum int
	// restoreadData is the target mapper restored rows are merged into.
	// NOTE(review): name carries a typo ("restoread"); it is referenced
	// elsewhere in this file, so it is kept as-is here.
	restoreadData aggfuncs.AggPartialResultMapper
	// partialResultsRestored[aggPos][rowPos] is the deserialized partial
	// result of agg function aggPos for row rowPos of the current chunk.
	partialResultsRestored [][]aggfuncs.PartialResult
}
// restoreFromOneSpillFile reads every chunk of one spilled file and merges its
// rows into restoreadData.
//
// It returns the memory delta caused by the deserialized partial results, the
// extra memory introduced while inserting keys into the mapper, and any error
// met during restoring.
func (p *parallelHashAggSpillHelper) restoreFromOneSpillFile(ctx sessionctx.Context, restoreadData aggfuncs.AggPartialResultMapper, diskIO *chunk.DataInDiskByChunks) (totalMemDelta int64, totalExpandMem int64, err error) {
	chunkNum := diskIO.NumChunks()
	aggFuncNum := len(p.aggFuncsForRestoring)
	rowCtx := &processRowContext{
		ctx:                    ctx,
		chunk:                  nil,        // set per chunk in the loop below
		rowPos:                 0,          // set per row in the loop below
		keyColPos:              aggFuncNum, // the group key column follows all agg columns
		aggFuncNum:             aggFuncNum,
		restoreadData:          restoreadData,
		partialResultsRestored: make([][]aggfuncs.PartialResult, aggFuncNum),
	}
	for i := range chunkNum {
		// Named `chk` to avoid shadowing the imported chunk package.
		chk, err := diskIO.GetChunk(i)
		if err != nil {
			return totalMemDelta, totalExpandMem, err
		}

		// Deserialize bytes to agg function's meta data.
		for aggPos, aggFunc := range p.aggFuncsForRestoring {
			partialResult, memDelta := aggFunc.DeserializePartialResult(chk)
			rowCtx.partialResultsRestored[aggPos] = partialResult
			totalMemDelta += memDelta
		}

		// Merge or create results row by row.
		rowNum := chk.NumRows()
		rowCtx.chunk = chk
		for rowPos := range rowNum {
			rowCtx.rowPos = rowPos
			memDelta, expandMem, err := p.processRow(rowCtx)
			if err != nil {
				return totalMemDelta, totalExpandMem, err
			}
			totalMemDelta += memDelta
			totalExpandMem += expandMem
		}
	}
	return totalMemDelta, totalExpandMem, nil
}
// processRow merges one restored row into context.restoreadData.
//
// If the group key already exists, the row's restored partial results are
// merged into the existing ones with the final workers' agg functions;
// otherwise the restored partial results are installed under the new key.
//
// NOTE(review): the named return expandMem is never assigned, so it is always
// 0 for callers; the map-expansion delta returned by Set is consumed into
// memTracker directly instead — confirm this accounting is intentional.
func (p *parallelHashAggSpillHelper) processRow(context *processRowContext) (totalMemDelta int64, expandMem int64, err error) {
	key := context.chunk.GetRow(context.rowPos).GetString(context.keyColPos)
	prs, ok := context.restoreadData.M[key]
	if ok {
		exprCtx := context.ctx.GetExprCtx()
		// The key has appeared before, merge results.
		for aggPos := range context.aggFuncNum {
			memDelta, err := p.finalWorkerAggFuncs[aggPos].MergePartialResult(exprCtx.GetEvalCtx(), context.partialResultsRestored[aggPos][context.rowPos], prs[aggPos])
			if err != nil {
				return totalMemDelta, 0, err
			}
			totalMemDelta += memDelta
		}
	} else {
		// New key: account for the key bytes, then install the restored
		// partial results directly (no merge needed).
		totalMemDelta += int64(len(key))
		results := make([]aggfuncs.PartialResult, context.aggFuncNum)
		delta := context.restoreadData.Set(key, results)
		if delta > 0 {
			// Map expansion cost is tracked immediately here.
			p.memTracker.Consume(delta)
		}
		for aggPos := range context.aggFuncNum {
			results[aggPos] = context.partialResultsRestored[aggPos][context.rowPos]
		}
	}
	return totalMemDelta, expandMem, nil
}
// Guarantee that processed data is at least 20% of the threshold, to avoid
// spilling too frequently.
func hasEnoughDataToSpill(aggTracker *memory.Tracker, passedInTracker *memory.Tracker) bool {
	// Integer division intentionally matches the original threshold semantics.
	spillThreshold := passedInTracker.GetBytesLimit() / 5
	return aggTracker.BytesConsumed() >= spillThreshold
}
// AggSpillDiskAction implements memory.ActionOnExceed for unparalleled HashAgg.
// If the memory quota of a query is exceeded, AggSpillDiskAction.Action is
// triggered.
type AggSpillDiskAction struct {
	memory.BaseOOMAction
	// e is the executor whose spill mode this action toggles.
	e *HashAggExec
	// spillTimes counts how many times spill mode has been set; Action stops
	// spilling (and falls back) after maxSpillTimes.
	spillTimes uint32
}
// Action set HashAggExec spill mode. If the executor is already in spill
// mode, there is not enough data to spill, or the spill count limit is
// reached, the fallback action (if any) is triggered instead.
func (a *AggSpillDiskAction) Action(t *memory.Tracker) {
	notInSpillMode := atomic.LoadUint32(&a.e.inSpillMode) == 0
	if notInSpillMode && hasEnoughDataToSpill(a.e.memTracker, t) && a.spillTimes < maxSpillTimes {
		a.spillTimes++
		logutil.BgLogger().Info(spillLogInfo,
			zap.Uint32("spillTimes", a.spillTimes),
			zap.Int64("consumed", t.BytesConsumed()),
			zap.Int64("quota", t.GetBytesLimit()))
		atomic.StoreUint32(&a.e.inSpillMode, 1)
		memory.QueryForceDisk.Add(1)
		return
	}
	if fallback := a.GetFallback(); fallback != nil {
		fallback.Action(t)
	}
}
// GetPriority get the priority of the Action.
func (*AggSpillDiskAction) GetPriority() int64 {
	return memory.DefSpillPriority
}
// ParallelAggSpillDiskAction implements memory.ActionOnExceed for parallel HashAgg.
type ParallelAggSpillDiskAction struct {
	memory.BaseOOMAction
	// e is the executor this action belongs to.
	e *HashAggExec
	// spillHelper coordinates the spill request with the agg workers.
	spillHelper *parallelHashAggSpillHelper
}
// Action set HashAggExec spill mode. When the spill flag cannot be set
// (not enough data to spill), the fallback OOM action is triggered instead.
func (p *ParallelAggSpillDiskAction) Action(t *memory.Tracker) {
	if !p.actionImpl(t) {
		p.TriggerFallBackAction(t)
	}
}
// actionImpl waits for any in-flight spill to finish, then tries to request a
// new spill. Return true if we successfully set flag.
func (p *ParallelAggSpillDiskAction) actionImpl(t *memory.Tracker) bool {
	p.spillHelper.waitForTheEndOfSpill()
	return p.spillHelper.setNeedSpill(p.e.memTracker, t)
}
// GetPriority get the priority of the Action.
func (*ParallelAggSpillDiskAction) GetPriority() int64 {
	return memory.DefSpillPriority
}