// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package aggregate_test

import (
	"context"
	"fmt"
	"math/rand"
	"strconv"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/pingcap/failpoint"
	"github.com/pingcap/tidb/pkg/executor/aggfuncs"
	"github.com/pingcap/tidb/pkg/executor/aggregate"
	"github.com/pingcap/tidb/pkg/executor/internal/exec"
	"github.com/pingcap/tidb/pkg/executor/internal/testutil"
	"github.com/pingcap/tidb/pkg/expression"
	"github.com/pingcap/tidb/pkg/expression/aggregation"
	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/sessionctx"
	"github.com/pingcap/tidb/pkg/sessionctx/variable"
	"github.com/pingcap/tidb/pkg/types"
	"github.com/pingcap/tidb/pkg/util/chunk"
	"github.com/pingcap/tidb/pkg/util/memory"
	"github.com/pingcap/tidb/pkg/util/mock"
	"github.com/stretchr/testify/require"
)

// Chunk schema in this test file: | column0: string | column1: float64 |

const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

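// resultsContainer accumulates the rows emitted by the executor under test so
// they can be validated after execution finishes.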
type resultsContainer struct {
	rows []chunk.Row
}

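// add copies the chunk before iterating it, because the caller resets and
// refills the same chunk on every call to Next.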
func (r *resultsContainer) add(chk *chunk.Chunk) {
	iter := chunk.NewIterator4Chunk(chk.CopyConstruct())
	for row := iter.Begin(); row != iter.End(); row = iter.Next() {
		r.rows = append(r.rows, row)
	}
}

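// check reports whether the collected rows match expectRes exactly. Column 0
// carries the group key and column 1 the aggregated sum; the sums are integral
// because every input value is 1, so comparing against strconv.Itoa suffices.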
func (r *resultsContainer) check(expectRes map[string]float64) bool {
	if len(r.rows) != len(expectRes) {
		return false
	}

	cols := getColumns()
	retFields := []*types.FieldType{cols[0].RetType, cols[1].RetType}
	for _, row := range r.rows {
		key := ""
		for i, field := range retFields {
			d := row.GetDatum(i, field)
			resStr, err := d.ToString()
			if err != nil {
				panic("failed to convert datum to string")
			}

			if i == 0 {
				key = resStr
			} else {
				expectVal, ok := expectRes[key]
				if !ok {
					return false
				}
				if resStr != strconv.Itoa(int(expectVal)) {
					return false
				}
			}
		}
	}
	return true
}

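// getRandString returns a random 5-letter string that serves as a group key.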
func getRandString() string {
	b := make([]byte, 5)
	for i := range b {
		b[i] = letterBytes[rand.Intn(len(letterBytes))]
	}
	return string(b)
}

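// generateData builds rowNum rows distributed over ndv distinct group keys and
// shuffles them. Column1 is constantly 1, so sum(column1) for a key equals the
// number of rows carrying that key.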
func generateData(rowNum int, ndv int) ([]string, []float64) {
	keys := make([]string, 0, ndv)
	for i := 0; i < ndv; i++ {
		keys = append(keys, getRandString())
	}

	col0Data := make([]string, 0, rowNum)
	col1Data := make([]float64, 0, rowNum)

	// Generate data
	for i := 0; i < rowNum; i++ {
		key := keys[i%ndv]
		col0Data = append(col0Data, key)
		col1Data = append(col1Data, 1) // Always 1
	}

	// Shuffle data
	rand.Shuffle(rowNum, func(i, j int) {
		col0Data[i], col0Data[j] = col0Data[j], col0Data[i]
		// There is no need to shuffle col1Data, as all of its values are 1.
	})
	return col0Data, col1Data
}

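// buildMockDataSource packs the two generated columns into chunks of at most
// MaxChunkSize rows and wraps them in a MockDataSource that serves as the
// child executor.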
func buildMockDataSource(opt testutil.MockDataSourceParameters, col0Data []string, col1Data []float64) *testutil.MockDataSource {
	baseExec := exec.NewBaseExecutor(opt.Ctx, opt.DataSchema, 0)
	mockDatasource := &testutil.MockDataSource{
		BaseExecutor: baseExec,
		ChunkPtr:     0,
		P:            opt,
		GenData:      nil,
		Chunks:       nil,
	}

	maxChunkSize := mockDatasource.MaxChunkSize()
	rowNum := len(col0Data)
	mockDatasource.GenData = make([]*chunk.Chunk, (rowNum+maxChunkSize-1)/maxChunkSize)
	for i := range mockDatasource.GenData {
		mockDatasource.GenData[i] = chunk.NewChunkWithCapacity(exec.RetTypes(mockDatasource), maxChunkSize)
	}

	for i := 0; i < rowNum; i++ {
		chkIdx := i / maxChunkSize
		mockDatasource.GenData[chkIdx].AppendString(0, col0Data[i])
		mockDatasource.GenData[chkIdx].AppendFloat64(1, col1Data[i])
	}

	return mockDatasource
}

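// generateResult computes the expected query output by counting how often each
// group key occurs; because every value in column1 is 1, the count equals the
// expected sum.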
func generateResult(col0 []string) map[string]float64 {
	result := make(map[string]float64)
	for _, key := range col0 {
		result[key]++
	}
	return result
}

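// getColumns returns the columns of the test schema: column0 is the varstring
// group-by key and column1 is the double value to be summed.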
func getColumns() []*expression.Column {
	return []*expression.Column{
		{Index: 0, RetType: types.NewFieldType(mysql.TypeVarString)},
		{Index: 1, RetType: types.NewFieldType(mysql.TypeDouble)},
	}
}

func getSchema() *expression.Schema {
	return expression.NewSchema(getColumns()...)
}

func getMockDataSourceParameters(ctx sessionctx.Context) testutil.MockDataSourceParameters {
	return testutil.MockDataSourceParameters{
		DataSchema: getSchema(),
		Ctx:        ctx,
	}
}

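// buildHashAggExecutor builds a parallel HashAggExec that evaluates
// `select col0, sum(col1) from t group by col0` with a concurrency of 5 for
// both the partial and the final stage. Each aggregate description is split
// into a partial and a final function so the two worker stages can exchange
// intermediate results.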
func buildHashAggExecutor(t *testing.T, ctx sessionctx.Context, child exec.Executor) *aggregate.HashAggExec {
	if err := ctx.GetSessionVars().SetSystemVar(variable.TiDBHashAggFinalConcurrency, fmt.Sprintf("%v", 5)); err != nil {
		t.Fatal(err)
	}
	if err := ctx.GetSessionVars().SetSystemVar(variable.TiDBHashAggPartialConcurrency, fmt.Sprintf("%v", 5)); err != nil {
		t.Fatal(err)
	}

	childCols := getColumns()
	schema := expression.NewSchema(childCols...)
	groupItems := []expression.Expression{childCols[0]}

	aggFirstRow, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), ast.AggFuncFirstRow, []expression.Expression{childCols[0]}, false)
	if err != nil {
		t.Fatal(err)
	}

	aggFunc, err := aggregation.NewAggFuncDesc(ctx.GetExprCtx(), ast.AggFuncSum, []expression.Expression{childCols[1]}, false)
	if err != nil {
		t.Fatal(err)
	}

	aggFuncs := []*aggregation.AggFuncDesc{aggFirstRow, aggFunc}

	aggExec := &aggregate.HashAggExec{
		BaseExecutor:     exec.NewBaseExecutor(ctx, schema, 0, child),
		Sc:               ctx.GetSessionVars().StmtCtx,
		PartialAggFuncs:  make([]aggfuncs.AggFunc, 0, len(aggFuncs)),
		GroupByItems:     groupItems,
		IsUnparallelExec: false,
	}

	partialOrdinal := 0
	for i, aggDesc := range aggFuncs {
		ordinal := []int{partialOrdinal}
		partialOrdinal++
		partialAggDesc, finalDesc := aggDesc.Split(ordinal)
		partialAggFunc := aggfuncs.Build(ctx.GetExprCtx(), partialAggDesc, i)
		finalAggFunc := aggfuncs.Build(ctx.GetExprCtx(), finalDesc, i)
		aggExec.PartialAggFuncs = append(aggExec.PartialAggFuncs, partialAggFunc)
		aggExec.FinalAggFuncs = append(aggExec.FinalAggFuncs, finalAggFunc)
	}

	aggExec.SetChildren(0, child)
	return aggExec
}

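// initCtx prepares the mock session: both chunk sizes are set to chkSize, a
// session-level memory tracker enforces hardLimitBytesNum, parallel hash-agg
// spill is enabled, and newRootExceedAction fires when the limit is exceeded.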
func initCtx(ctx *mock.Context, newRootExceedAction *testutil.MockActionOnExceed, hardLimitBytesNum int64, chkSize int) {
	ctx.GetSessionVars().InitChunkSize = chkSize
	ctx.GetSessionVars().MaxChunkSize = chkSize
	ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSession, hardLimitBytesNum)
	ctx.GetSessionVars().TrackAggregateMemoryUsage = true
	ctx.GetSessionVars().EnableParallelHashaggSpill = true
	ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
	ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
	ctx.GetSessionVars().MemTracker.SetActionOnExceed(newRootExceedAction)
}

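// executeCorrectResultTest runs the aggregation below to completion, then
// verifies that spill was triggered and that the produced rows match the
// expected result: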
// select t0, sum(t1) from t group by t0;
func executeCorrectResultTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggExec, dataSource *testutil.MockDataSource, result map[string]float64) {
	if aggExec == nil {
		aggExec = buildHashAggExecutor(t, ctx, dataSource)
	}
	dataSource.PrepareChunks()
	tmpCtx := context.Background()
	chk := exec.NewFirstChunk(aggExec)
	resContainer := resultsContainer{}
	require.NoError(t, aggExec.Open(tmpCtx))

	for {
		err := aggExec.Next(tmpCtx, chk)
		require.NoError(t, err)
		if chk.NumRows() == 0 {
			break
		}
		resContainer.add(chk)
		chk.Reset()
	}
	require.NoError(t, aggExec.Close())

	require.True(t, aggExec.IsSpillTriggeredForTest())
	require.True(t, resContainer.check(result))
}

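// fallBackActionTest consumes roughly 80% of the hard memory limit up front so
// that the aggregation drives the session tracker over the limit and the
// registered exceed-action is triggered at least once.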
func fallBackActionTest(t *testing.T) {
	newRootExceedAction := new(testutil.MockActionOnExceed)
	hardLimitBytesNum := int64(6000000)

	ctx := mock.NewContext()
	initCtx(ctx, newRootExceedAction, hardLimitBytesNum, 4096)

	// Consume lots of memory in advance to help trigger the fallback action.
	ctx.GetSessionVars().MemTracker.Consume(int64(float64(hardLimitBytesNum) * 0.799999))

	rowNum := 10000 + rand.Intn(10000)
	ndv := 5000 + rand.Intn(5000)
	col0, col1 := generateData(rowNum, ndv)
	opt := getMockDataSourceParameters(ctx)
	dataSource := buildMockDataSource(opt, col0, col1)

	aggExec := buildHashAggExecutor(t, ctx, dataSource)
	dataSource.PrepareChunks()
	tmpCtx := context.Background()
	chk := exec.NewFirstChunk(aggExec)
	resContainer := resultsContainer{}
	require.NoError(t, aggExec.Open(tmpCtx))

	for {
		aggExec.Next(tmpCtx, chk)
		if chk.NumRows() == 0 {
			break
		}
		resContainer.add(chk)
		chk.Reset()
	}
	aggExec.Close()
	require.Less(t, 0, newRootExceedAction.GetTriggeredNum())
}

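// randomFailTest races a background goroutine that closes the executor after a
// random delay against the main loop calling Next. The sync.Once guarantees
// that Close runs exactly once, no matter which side reaches it first.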
func randomFailTest(t *testing.T, ctx *mock.Context, aggExec *aggregate.HashAggExec, dataSource *testutil.MockDataSource) {
	if aggExec == nil {
		aggExec = buildHashAggExecutor(t, ctx, dataSource)
	}
	dataSource.PrepareChunks()
	tmpCtx := context.Background()
	chk := exec.NewFirstChunk(aggExec)
	require.NoError(t, aggExec.Open(tmpCtx))

	goRoutineWaiter := sync.WaitGroup{}
	goRoutineWaiter.Add(1)
	defer goRoutineWaiter.Wait()

	once := sync.Once{}

	go func() {
		time.Sleep(time.Duration(rand.Int31n(300)) * time.Millisecond)
		once.Do(func() {
			aggExec.Close()
		})
		goRoutineWaiter.Done()
	}()

	for {
		err := aggExec.Next(tmpCtx, chk)
		if err != nil {
			once.Do(func() {
				require.NoError(t, aggExec.Close())
			})
			break
		}
		if chk.NumRows() == 0 {
			break
		}
		chk.Reset()
	}
	once.Do(func() {
		aggExec.Close()
	})
}

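// TestGetCorrectResult checks result correctness under spilling: a failpoint
// slows down some partial workers while a background goroutine keeps consuming
// from the session memory tracker so that potential data races surface.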
func TestGetCorrectResult(t *testing.T) {
	newRootExceedAction := new(testutil.MockActionOnExceed)
	hardLimitBytesNum := int64(6000000)

	ctx := mock.NewContext()
	initCtx(ctx, newRootExceedAction, hardLimitBytesNum, 256)

	rowNum := 100000
	ndv := 50000
	col0, col1 := generateData(rowNum, ndv)
	result := generateResult(col0)
	opt := getMockDataSourceParameters(ctx)
	dataSource := buildMockDataSource(opt, col0, col1)

	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/aggregate/slowSomePartialWorkers", `return(true)`))

	finished := atomic.Bool{}
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		tracker := ctx.GetSessionVars().MemTracker
		for {
			if finished.Load() {
				break
			}
			// Consume from another goroutine so that potential data races can be detected.
			tracker.Consume(1)
			time.Sleep(1 * time.Millisecond)
		}
		wg.Done()
	}()

	aggExec := buildHashAggExecutor(t, ctx, dataSource)
	for i := 0; i < 5; i++ {
		executeCorrectResultTest(t, ctx, nil, dataSource, result)
		executeCorrectResultTest(t, ctx, aggExec, dataSource, result)
	}

	finished.Store(true)
	wg.Wait()
	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/aggregate/slowSomePartialWorkers"))
}

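// TestFallBackAction runs the fallback scenario 50 times, each run with
// randomized row counts and NDVs.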
func TestFallBackAction(t *testing.T) {
	for i := 0; i < 50; i++ {
		fallBackActionTest(t)
	}
}

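// TestRandomFail enables spill and injects random disk errors while executions
// are cancelled at random points; the test succeeds as long as no query hangs.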
func TestRandomFail(t *testing.T) {
	newRootExceedAction := new(testutil.MockActionOnExceed)
	hardLimitBytesNum := int64(5000000)

	ctx := mock.NewContext()
	initCtx(ctx, newRootExceedAction, hardLimitBytesNum, 32)

	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/aggregate/enableAggSpillIntest", `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/aggregate/enableAggSpillIntest"))
	}()
	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/util/chunk/ChunkInDiskError", `return(true)`))
	defer func() {
		require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/util/chunk/ChunkInDiskError"))
	}()

	rowNum := 100000 + rand.Intn(100000)
	ndv := 50000 + rand.Intn(50000)
	col0, col1 := generateData(rowNum, ndv)
	opt := getMockDataSourceParameters(ctx)
	dataSource := buildMockDataSource(opt, col0, col1)

	finished := atomic.Bool{}
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		tracker := ctx.GetSessionVars().MemTracker
		for {
			if finished.Load() {
				break
			}
			// Consume from another goroutine so that potential data races can be detected.
			tracker.Consume(1)
			time.Sleep(3 * time.Millisecond)
		}
		wg.Done()
	}()

	// The test succeeds as long as no query hangs.
	aggExec := buildHashAggExecutor(t, ctx, dataSource)
	for i := 0; i < 30; i++ {
		randomFailTest(t, ctx, nil, dataSource)
		randomFailTest(t, ctx, aggExec, dataSource)
	}

	finished.Store(true)
	wg.Wait()
}