Files
tidb/pkg/executor/aggregate/agg_spill_test.go
2025-12-16 12:39:53 +00:00

526 lines
16 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package aggregate_test
import (
"context"
"fmt"
"math/rand"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/executor/aggfuncs"
"github.com/pingcap/tidb/pkg/executor/aggregate"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/executor/internal/testutil"
"github.com/pingcap/tidb/pkg/executor/internal/util"
"github.com/pingcap/tidb/pkg/expression"
"github.com/pingcap/tidb/pkg/expression/aggregation"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/sessionctx"
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/mock"
"github.com/stretchr/testify/require"
)
// Chunk schema in this test file: | column0: string | column1: float64 |
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

// getRandString returns a 5-byte string whose bytes are picked uniformly at
// random from letterBytes.
func getRandString() string {
	const keyLen = 5
	var buf [keyLen]byte
	for pos := 0; pos < keyLen; pos++ {
		buf[pos] = letterBytes[rand.Intn(len(letterBytes))]
	}
	return string(buf[:])
}
func generateData(rowNum int, ndv int) ([]string, []float64) {
keys := make([]string, 0)
for range ndv {
keys = append(keys, getRandString())
}
col0Data := make([]string, 0)
col1Data := make([]float64, 0)
// Generate data
for range rowNum {
key := keys[rand.Intn(ndv)]
col0Data = append(col0Data, key)
col1Data = append(col1Data, float64(rand.Intn(10000000)))
}
// Shuffle data
rand.Shuffle(rowNum, func(i, j int) {
col0Data[i], col0Data[j] = col0Data[j], col0Data[i]
// There is no need to shuffle col2Data as all of it's values are 1.
})
return col0Data, col1Data
}
// buildMockDataSource wraps the two given columns into a MockDataSource.
//
// The rows are packed, in order, into chunks of MaxChunkSize rows each and
// stored in GenData; col0Data fills column 0 and col1Data fills column 1.
// col1Data must be at least as long as col0Data.
func buildMockDataSource(opt testutil.MockDataSourceParameters, col0Data []string, col1Data []float64) *testutil.MockDataSource {
	src := &testutil.MockDataSource{
		BaseExecutor: exec.NewBaseExecutor(opt.Ctx, opt.DataSchema, 0),
		P:            opt,
	}

	chunkCap := src.MaxChunkSize()
	totalRows := len(col0Data)
	// Ceiling division: the last chunk may be only partially filled.
	chunkCnt := (totalRows + chunkCap - 1) / chunkCap

	chunks := make([]*chunk.Chunk, 0, chunkCnt)
	for len(chunks) < chunkCnt {
		chunks = append(chunks, chunk.NewChunkWithCapacity(exec.RetTypes(src), chunkCap))
	}
	src.GenData = chunks

	for row, key := range col0Data {
		target := src.GenData[row/chunkCap]
		target.AppendString(0, key)
		target.AppendFloat64(1, col1Data[row])
	}
	return src
}
// generateCMPFunc returns a row comparator that compares two rows column by
// column, in the order given by fieldTypes, and reports the result of the
// first column that differs (0 when every column compares equal).
func generateCMPFunc(fieldTypes []*types.FieldType) func(chunk.Row, chunk.Row) int {
	perColumn := make([]chunk.CompareFunc, len(fieldTypes))
	for idx, ft := range fieldTypes {
		perColumn[idx] = chunk.GetCompareFunc(ft)
	}

	return func(lhs, rhs chunk.Row) int {
		for col := range perColumn {
			if res := perColumn[col](lhs, col, rhs, col); res != 0 {
				return res
			}
		}
		return 0
	}
}
// sortRows sorts rows in place in ascending order, using the column-wise
// comparator built from fieldTypes, and returns the same slice.
func sortRows(rows []chunk.Row, fieldTypes []*types.FieldType) []chunk.Row {
	rowCmp := generateCMPFunc(fieldTypes)
	less := func(a, b int) bool {
		return rowCmp(rows[a], rows[b]) < 0
	}
	sort.Slice(rows, less)
	return rows
}
// generateResult runs a freshly built hash aggregation over dataSource and
// returns all produced rows sorted with sortRows.
//
// The run is expected to finish without spilling: the function fails the test
// if the executor reports that spill was triggered, if memory usage tracking
// is reported invalid, or if Open, Next or Close returns an error.
func generateResult(t *testing.T, ctx *mock.Context, dataSource *testutil.MockDataSource, fileNamePrefixForTest string) []chunk.Row {
	aggExec := buildHashAggExecutor(t, ctx, dataSource, fileNamePrefixForTest)
	dataSource.PrepareChunks()
	tmpCtx := context.Background()
	resultRows := make([]chunk.Row, 0)

	require.NoError(t, aggExec.Open(tmpCtx))
	for {
		// A new chunk is allocated for every call because the collected rows
		// are kept after the loop (presumably they reference the chunk's
		// storage, so the chunk must not be reused — confirm against chunk.Row).
		chk := exec.NewFirstChunk(aggExec)
		require.NoError(t, aggExec.Next(tmpCtx, chk))
		rowNum := chk.NumRows()
		if rowNum == 0 {
			break
		}
		for i := range rowNum {
			resultRows = append(resultRows, chk.GetRow(i))
		}
	}

	require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
	require.NoError(t, aggExec.Close())
	require.False(t, aggExec.IsSpillTriggeredForTest())
	return sortRows(resultRows, getRetTypes())
}
// getRetTypes returns the field types of the six aggregation output columns:
// one var-string column, then double, longlong and three more doubles. Each
// call allocates fresh FieldType values.
func getRetTypes() []*types.FieldType {
	typeCodes := []byte{
		mysql.TypeVarString,
		mysql.TypeDouble,
		mysql.TypeLonglong,
		mysql.TypeDouble,
		mysql.TypeDouble,
		mysql.TypeDouble,
	}
	retTypes := make([]*types.FieldType, 0, len(typeCodes))
	for _, code := range typeCodes {
		retTypes = append(retTypes, types.NewFieldType(code))
	}
	return retTypes
}
// getColumns returns six columns: one var-string column reading input index 0
// followed by five double columns that all read input index 1. Every column
// gets its own freshly allocated FieldType.
func getColumns() []*expression.Column {
	const doubleColCnt = 5
	cols := make([]*expression.Column, 0, 1+doubleColCnt)
	cols = append(cols, &expression.Column{
		Index:   0,
		RetType: types.NewFieldType(mysql.TypeVarString),
	})
	for len(cols) < 1+doubleColCnt {
		cols = append(cols, &expression.Column{
			Index:   1,
			RetType: types.NewFieldType(mysql.TypeDouble),
		})
	}
	return cols
}
// getSchema builds a schema out of the columns returned by getColumns.
func getSchema() *expression.Schema {
	cols := getColumns()
	return expression.NewSchema(cols...)
}
// getMockDataSourceParameters returns mock data source parameters bound to ctx
// and using the schema from getSchema; all other fields keep their zero value.
func getMockDataSourceParameters(ctx sessionctx.Context) testutil.MockDataSourceParameters {
	var params testutil.MockDataSourceParameters
	params.DataSchema = getSchema()
	params.Ctx = ctx
	return params
}
// buildHashAggExecutor assembles a parallel HashAggExec on top of child.
//
// It sets both the final and the partial hash-agg concurrency to 5, groups by
// column 0 and computes first_row(col0), sum, count, avg, min and max of
// column 1. Every aggregate descriptor is split into a partial and a final
// function. Any error while configuring the session or creating a descriptor
// aborts the test through t.Fatal.
func buildHashAggExecutor(t *testing.T, ctx sessionctx.Context, child exec.Executor, fileNamePrefixForTest string) *aggregate.HashAggExec {
	sessVars := ctx.GetSessionVars()
	for _, varName := range []string{vardef.TiDBHashAggFinalConcurrency, vardef.TiDBHashAggPartialConcurrency} {
		if err := sessVars.SetSystemVar(varName, fmt.Sprintf("%v", 5)); err != nil {
			t.Fatal(err)
		}
	}

	childCols := getColumns()
	schema := expression.NewSchema(childCols...)

	// Aggregate name paired with the single input column it is applied to,
	// listed in output order.
	specs := []struct {
		name string
		arg  expression.Expression
	}{
		{ast.AggFuncFirstRow, childCols[0]},
		{ast.AggFuncSum, childCols[1]},
		{ast.AggFuncCount, childCols[1]},
		{ast.AggFuncAvg, childCols[1]},
		{ast.AggFuncMin, childCols[1]},
		{ast.AggFuncMax, childCols[1]},
	}
	descs := make([]*aggregation.AggFuncDesc, 0, len(specs))
	for _, spec := range specs {
		desc, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), spec.name, []expression.Expression{spec.arg}, false)
		if err != nil {
			t.Fatal(err)
		}
		descs = append(descs, desc)
	}

	hashAgg := &aggregate.HashAggExec{
		BaseExecutor:          exec.NewBaseExecutor(ctx, schema, 0, child),
		Sc:                    sessVars.StmtCtx,
		PartialAggFuncs:       make([]aggfuncs.AggFunc, 0, len(descs)),
		FinalAggFuncs:         make([]aggfuncs.AggFunc, 0, len(descs)),
		GroupByItems:          []expression.Expression{childCols[0]},
		IsUnparallelExec:      false,
		FileNamePrefixForTest: fileNamePrefixForTest,
	}

	nextOrdinal := 0
	for idx, desc := range descs {
		ordinals := []int{nextOrdinal}
		nextOrdinal++
		// avg is handed a second ordinal and consumes one more slot.
		if desc.Name == ast.AggFuncAvg {
			ordinals = append(ordinals, nextOrdinal+1)
			nextOrdinal++
		}
		partialDesc, finalDesc := desc.Split(ordinals)
		partialFunc := aggfuncs.Build(ctx.GetExprCtx(), partialDesc, idx)
		finalFunc := aggfuncs.Build(ctx.GetExprCtx(), finalDesc, idx)
		hashAgg.PartialAggFuncs = append(hashAgg.PartialAggFuncs, partialFunc)
		hashAgg.FinalAggFuncs = append(hashAgg.FinalAggFuncs, finalFunc)
	}

	hashAgg.SetChildren(0, child)
	return hashAgg
}
// initCtx prepares the session variables of ctx for an aggregation spill test.
//
// Both the initial and the maximum chunk size are set to chkSize, aggregate
// memory tracking and parallel hash-agg spill are switched on, and a new
// session-level memory tracker limited to hardLimitBytesNum is installed with
// newRootExceedAction as its on-exceed action. A new statement-level tracker
// (created with -1 as its limit) is attached underneath the session tracker.
func initCtx(ctx *mock.Context, newRootExceedAction *testutil.MockActionOnExceed, hardLimitBytesNum int64, chkSize int) {
	vars := ctx.GetSessionVars()

	vars.InitChunkSize = chkSize
	vars.MaxChunkSize = chkSize

	sessionTracker := memory.NewTracker(memory.LabelForSession, hardLimitBytesNum)
	vars.MemTracker = sessionTracker
	vars.TrackAggregateMemoryUsage = true
	vars.EnableParallelHashaggSpill = true

	stmtTracker := memory.NewTracker(memory.LabelForSQLText, -1)
	vars.StmtCtx.MemTracker = stmtTracker
	stmtTracker.AttachTo(sessionTracker)
	sessionTracker.SetActionOnExceed(newRootExceedAction)
}
// checkResult reports whether expectResult and actualResult have the same
// length and every pair of rows at the same position renders to the same
// string under retTypes.
func checkResult(expectResult []chunk.Row, actualResult []chunk.Row, retTypes []*types.FieldType) bool {
	if len(actualResult) != len(expectResult) {
		return false
	}
	for idx, want := range expectResult {
		got := actualResult[idx]
		if want.ToString(retTypes) != got.ToString(retTypes) {
			return false
		}
	}
	return true
}
// executeCorrecResultTest runs the hash aggregation over dataSource and checks
// that the sorted output equals expectResult.
//
// When aggExec is nil a new executor is built; otherwise the given executor is
// reused. Unlike generateResult, this run must spill: the test fails if spill
// was not triggered, if memory usage tracking is reported invalid, or if
// Open, Next or Close returns an error.
func executeCorrecResultTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggExec, dataSource *testutil.MockDataSource, expectResult []chunk.Row, fileNamePrefixForTest string) {
	if aggExec == nil {
		aggExec = buildHashAggExecutor(t, ctx, dataSource, fileNamePrefixForTest)
	}
	dataSource.PrepareChunks()
	tmpCtx := context.Background()
	resultRows := make([]chunk.Row, 0)

	require.NoError(t, aggExec.Open(tmpCtx))
	for {
		// A new chunk is allocated for every call because the collected rows
		// are kept after the loop (presumably they reference the chunk's
		// storage, so the chunk must not be reused — confirm against chunk.Row).
		chk := exec.NewFirstChunk(aggExec)
		require.NoError(t, aggExec.Next(tmpCtx, chk))
		rowNum := chk.NumRows()
		if rowNum == 0 {
			break
		}
		for i := range rowNum {
			resultRows = append(resultRows, chk.GetRow(i))
		}
	}

	require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
	require.NoError(t, aggExec.Close())
	require.True(t, aggExec.IsSpillTriggeredForTest())

	retTypes := getRetTypes()
	resultRows = sortRows(resultRows, retTypes)
	require.True(t, checkResult(expectResult, resultRows, retTypes))
}
// fallBackActionTest runs one hash aggregation with most of the memory quota
// already consumed and checks that the root on-exceed action was triggered at
// least once.
//
// A fresh context, data set (10000-19999 rows, 5000-9999 candidate keys) and
// executor are created on every call. The test also fails if Open or Close
// returns an error or if memory usage tracking is reported invalid.
func fallBackActionTest(t *testing.T, fileNamePrefixForTest string) {
	newRootExceedAction := new(testutil.MockActionOnExceed)
	hardLimitBytesNum := int64(6000000)
	ctx := mock.NewContext()
	initCtx(ctx, newRootExceedAction, hardLimitBytesNum, 4096)

	// Consume lots of memory in advance to help to trigger fallback action.
	ctx.GetSessionVars().MemTracker.Consume(int64(float64(hardLimitBytesNum) * 0.799999))

	rowNum := 10000 + rand.Intn(10000)
	ndv := 5000 + rand.Intn(5000)
	col0, col1 := generateData(rowNum, ndv)
	opt := getMockDataSourceParameters(ctx)
	dataSource := buildMockDataSource(opt, col0, col1)

	aggExec := buildHashAggExecutor(t, ctx, dataSource, fileNamePrefixForTest)
	dataSource.PrepareChunks()
	tmpCtx := context.Background()
	chk := exec.NewFirstChunk(aggExec)

	require.NoError(t, aggExec.Open(tmpCtx))
	for {
		// NOTE(review): the error from Next has never been asserted in this
		// helper and that is kept as-is; confirm whether an error is an
		// acceptable outcome here before tightening it.
		_ = aggExec.Next(tmpCtx, chk)
		if chk.NumRows() == 0 {
			break
		}
		chk.Reset()
	}

	require.False(t, aggExec.IsInvalidMemoryUsageTrackingForTest())
	require.NoError(t, aggExec.Close())
	require.Less(t, 0, newRootExceedAction.GetTriggeredNum())
}
// randomFailTest drains the hash aggregation while a background goroutine
// closes the executor after a random delay of up to 300ms.
//
// When aggExec is nil a new executor is built; otherwise the given executor is
// reused. The executor is closed exactly once, by whichever comes first: the
// background goroutine, or the draining loop once it stops (on an error from
// Next or on an empty chunk). The helper returns only after the background
// goroutine has finished.
//
// Checks performed around Close are reported with t.Error/t.Errorf rather than
// require, because they may run on the background goroutine and t.FailNow must
// only be called from the goroutine running the test.
func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggExec, dataSource *testutil.MockDataSource, fileNamePrefixForTest string) {
	if aggExec == nil {
		aggExec = buildHashAggExecutor(t, ctx, dataSource, fileNamePrefixForTest)
	}
	dataSource.PrepareChunks()
	tmpCtx := context.Background()
	chk := exec.NewFirstChunk(aggExec)
	require.NoError(t, aggExec.Open(tmpCtx))

	// closeOnce validates memory tracking and closes the executor; only the
	// first call has any effect. It is safe to call from any goroutine.
	var once sync.Once
	closeOnce := func() {
		once.Do(func() {
			if aggExec.IsInvalidMemoryUsageTrackingForTest() {
				t.Error("invalid memory usage tracking detected before closing the executor")
			}
			if err := aggExec.Close(); err != nil {
				t.Errorf("closing the hash agg executor: %v", err)
			}
		})
	}

	var closer sync.WaitGroup
	closer.Add(1)
	defer closer.Wait()
	go func() {
		defer closer.Done()
		time.Sleep(time.Duration(rand.Int31n(300)) * time.Millisecond)
		closeOnce()
	}()

	for {
		// An error from Next is an accepted outcome here; it only ends the loop.
		if err := aggExec.Next(tmpCtx, chk); err != nil {
			break
		}
		if chk.NumRows() == 0 {
			break
		}
		chk.Reset()
	}
	closeOnce()
}
// TestGetCorrectResult checks that a hash aggregation that spills returns the
// same rows as a run that does not spill.
//
// sql: select col0, sum(col1), count(col1), avg(col1), min(col1), max(col1) from t group by t.col0;
func TestGetCorrectResult(t *testing.T) {
	defer config.RestoreFunc()()
	config.UpdateGlobal(func(conf *config.Config) {
		conf.TempStoragePath = t.TempDir()
	})

	testFuncName := util.GetFunctionName()
	newRootExceedAction := new(testutil.MockActionOnExceed)
	ctx := mock.NewContext()

	// Baseline run: generateResult asserts that no spill happens, so the
	// session tracker is created with -1 as its limit.
	initCtx(ctx, newRootExceedAction, -1, 1024)
	rowNum := 100000
	ndv := 50000
	col0, col1 := generateData(rowNum, ndv)
	opt := getMockDataSourceParameters(ctx)
	dataSource := buildMockDataSource(opt, col0, col1)
	result := generateResult(t, ctx, dataSource, testFuncName)

	const slowWorkersFailpoint = "github.com/pingcap/tidb/pkg/executor/aggregate/slowSomePartialWorkers"
	require.NoError(t, failpoint.Enable(slowWorkersFailpoint, `return(true)`))
	// The Disable call must sit inside a closure: arguments of a deferred call
	// are evaluated when the defer statement executes, so deferring
	// require.NoError(t, failpoint.Disable(...)) directly would switch the
	// failpoint off immediately.
	defer func() {
		require.NoError(t, failpoint.Disable(slowWorkersFailpoint))
	}()

	hardLimitBytesNum := int64(6000000)
	initCtx(ctx, newRootExceedAction, hardLimitBytesNum, 256)

	var finished atomic.Bool
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		tracker := ctx.GetSessionVars().MemTracker
		for !finished.Load() {
			// Mock consuming in another goroutine, so that we can test potential data race.
			tracker.Consume(1)
			time.Sleep(1 * time.Millisecond)
		}
	}()
	// stopConsumer is also deferred so that the goroutine is stopped when an
	// assertion aborts the test early; calling it twice is harmless.
	stopConsumer := func() {
		finished.Store(true)
		wg.Wait()
	}
	defer stopConsumer()

	// Alternate between a fresh executor and a reused one.
	aggExec := buildHashAggExecutor(t, ctx, dataSource, testFuncName)
	for range 5 {
		executeCorrecResultTest(t, ctx, nil, dataSource, result, testFuncName)
		executeCorrecResultTest(t, ctx, aggExec, dataSource, result, testFuncName)
	}

	stopConsumer()
	util.CheckNoLeakFiles(t, testFuncName)
}
// TestFallBackAction runs fallBackActionTest 50 times with the temp storage
// path pointed at a per-test directory, then verifies that no files prefixed
// with the test function name were left behind.
func TestFallBackAction(t *testing.T) {
	restoreConfig := config.RestoreFunc()
	defer restoreConfig()
	config.UpdateGlobal(func(conf *config.Config) {
		conf.TempStoragePath = t.TempDir()
	})

	const rounds = 50
	prefix := util.GetFunctionName()
	for round := 0; round < rounds; round++ {
		fallBackActionTest(t, prefix)
	}
	util.CheckNoLeakFiles(t, prefix)
}
// TestRandomFail repeatedly runs the hash aggregation with the
// enableAggSpillIntest and ChunkInDiskError failpoints enabled while the
// executor is closed at a random moment (see randomFailTest).
//
// The test passes when every run terminates instead of hanging and no files
// prefixed with the test function name are left behind.
func TestRandomFail(t *testing.T) {
	defer config.RestoreFunc()()
	config.UpdateGlobal(func(conf *config.Config) {
		conf.TempStoragePath = t.TempDir()
	})

	testFuncName := util.GetFunctionName()
	newRootExceedAction := new(testutil.MockActionOnExceed)
	hardLimitBytesNum := int64(5000000)
	ctx := mock.NewContext()
	initCtx(ctx, newRootExceedAction, hardLimitBytesNum, 32)

	// A failpoint that silently fails to enable would turn this test into a
	// no-op, so both Enable and Disable errors are checked.
	const (
		aggSpillIntestFailpoint = "github.com/pingcap/tidb/pkg/executor/aggregate/enableAggSpillIntest"
		chunkInDiskFailpoint    = "github.com/pingcap/tidb/pkg/util/chunk/ChunkInDiskError"
	)
	require.NoError(t, failpoint.Enable(aggSpillIntestFailpoint, `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable(aggSpillIntestFailpoint))
	}()
	require.NoError(t, failpoint.Enable(chunkInDiskFailpoint, `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable(chunkInDiskFailpoint))
	}()

	rowNum := 100000 + rand.Intn(100000)
	ndv := 50000 + rand.Intn(50000)
	col0, col1 := generateData(rowNum, ndv)
	opt := getMockDataSourceParameters(ctx)
	dataSource := buildMockDataSource(opt, col0, col1)

	var finished atomic.Bool
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		tracker := ctx.GetSessionVars().MemTracker
		for !finished.Load() {
			// Mock consuming in another goroutine, so that we can test potential data race.
			tracker.Consume(1)
			time.Sleep(3 * time.Millisecond)
		}
	}()
	// stopConsumer is also deferred so that the goroutine is stopped when an
	// assertion aborts the test early; calling it twice is harmless.
	stopConsumer := func() {
		finished.Store(true)
		wg.Wait()
	}
	defer stopConsumer()

	// Test is successful when all sqls are not hung.
	// Alternate between a fresh executor and a reused one.
	aggExec := buildHashAggExecutor(t, ctx, dataSource, testFuncName)
	for range 30 {
		randomFailTest(t, ctx, nil, dataSource, testFuncName)
		randomFailTest(t, ctx, aggExec, dataSource, testFuncName)
	}

	stopConsumer()
	util.CheckNoLeakFiles(t, testFuncName)
}