Files
tidb/executor/aggfuncs/aggfunc_test.go
2022-09-05 19:42:55 +08:00

920 lines
31 KiB
Go

// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package aggfuncs_test
import (
"fmt"
"strconv"
"testing"
"time"
"unsafe"
"github.com/dgryski/go-farm"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/executor/aggfuncs"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/planner/util"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/codec"
"github.com/pingcap/tidb/util/collate"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/mock"
"github.com/pingcap/tidb/util/set"
"github.com/stretchr/testify/require"
)
const (
// separator argument for group_concat() test cases
separator = " "
)
type aggTest struct {
dataType *types.FieldType
numRows int
dataGen func(i int) types.Datum
funcName string
results []types.Datum
orderBy bool
}
func (p *aggTest) genSrcChk() *chunk.Chunk {
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
for i := 0; i < p.numRows; i++ {
dt := p.dataGen(i)
srcChk.AppendDatum(0, &dt)
}
srcChk.AppendDatum(0, &types.Datum{})
return srcChk
}
// messUpChunk messes up the chunk for testing memory reference.
func (p *aggTest) messUpChunk(c *chunk.Chunk) {
for i := 0; i < p.numRows; i++ {
raw := c.Column(0).GetRaw(i)
for i := range raw {
raw[i] = 255
}
}
}
type multiArgsAggTest struct {
dataTypes []*types.FieldType
retType *types.FieldType
numRows int
dataGens []func(i int) types.Datum
funcName string
results []types.Datum
orderBy bool
}
func (p *multiArgsAggTest) genSrcChk() *chunk.Chunk {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
srcChk.AppendDatum(0, &types.Datum{})
return srcChk
}
// messUpChunk messes up the chunk for testing memory reference.
func (p *multiArgsAggTest) messUpChunk(c *chunk.Chunk) {
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
raw := c.Column(j).GetRaw(i)
for i := range raw {
raw[i] = 255
}
}
}
}
type updateMemDeltaGens func(*chunk.Chunk, *types.FieldType) (memDeltas []int64, err error)
func defaultUpdateMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (memDeltas []int64, err error) {
memDeltas = make([]int64, 0)
for i := 0; i < srcChk.NumRows(); i++ {
memDeltas = append(memDeltas, int64(0))
}
return memDeltas, nil
}
func approxCountDistinctUpdateMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (memDeltas []int64, err error) {
memDeltas = make([]int64, 0)
buf := make([]byte, 8)
p := aggfuncs.NewPartialResult4ApproxCountDistinct()
for i := 0; i < srcChk.NumRows(); i++ {
row := srcChk.GetRow(i)
if row.IsNull(0) {
memDeltas = append(memDeltas, int64(0))
continue
}
oldMemUsage := p.MemUsage()
switch dataType.GetType() {
case mysql.TypeLonglong:
val := row.GetInt64(0)
*(*int64)(unsafe.Pointer(&buf[0])) = val
case mysql.TypeString:
val := row.GetString(0)
buf = codec.EncodeCompactBytes(buf, hack.Slice(val))
default:
return memDeltas, errors.Errorf("unsupported type - %v", dataType.GetType())
}
x := farm.Hash64(buf)
p.InsertHash64(x)
newMemUsage := p.MemUsage()
memDelta := newMemUsage - oldMemUsage
memDeltas = append(memDeltas, memDelta)
}
return memDeltas, nil
}
func distinctUpdateMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (memDeltas []int64, err error) {
valSet := set.NewStringSet()
memDeltas = make([]int64, 0)
for i := 0; i < srcChk.NumRows(); i++ {
row := srcChk.GetRow(i)
if row.IsNull(0) {
memDeltas = append(memDeltas, int64(0))
continue
}
val := ""
memDelta := int64(0)
switch dataType.GetType() {
case mysql.TypeLonglong:
val = strconv.FormatInt(row.GetInt64(0), 10)
case mysql.TypeFloat:
val = strconv.FormatFloat(float64(row.GetFloat32(0)), 'f', 6, 64)
case mysql.TypeDouble:
val = strconv.FormatFloat(row.GetFloat64(0), 'f', 6, 64)
case mysql.TypeNewDecimal:
decimal := row.GetMyDecimal(0)
hash, err := decimal.ToHashKey()
if err != nil {
memDeltas = append(memDeltas, int64(0))
continue
}
val = string(hack.String(hash))
memDelta = int64(len(val))
case mysql.TypeString:
val = row.GetString(0)
memDelta = int64(len(val))
case mysql.TypeDate:
val = row.GetTime(0).String()
// the distinct count aggFunc need 16 bytes to encode the Datetime type.
memDelta = 16
case mysql.TypeDuration:
val = strconv.FormatInt(row.GetInt64(0), 10)
case mysql.TypeJSON:
jsonVal := row.GetJSON(0)
bytes := make([]byte, 0)
bytes = append(bytes, jsonVal.TypeCode)
bytes = append(bytes, jsonVal.Value...)
val = string(bytes)
memDelta = int64(len(val))
default:
return memDeltas, errors.Errorf("unsupported type - %v", dataType.GetType())
}
if valSet.Exist(val) {
memDeltas = append(memDeltas, int64(0))
continue
}
valSet.Insert(val)
memDeltas = append(memDeltas, memDelta)
}
return memDeltas, nil
}
func rowMemDeltaGens(srcChk *chunk.Chunk, dataType *types.FieldType) (memDeltas []int64, err error) {
memDeltas = make([]int64, 0)
for i := 0; i < srcChk.NumRows(); i++ {
memDelta := aggfuncs.DefRowSize
memDeltas = append(memDeltas, memDelta)
}
return memDeltas, nil
}
type multiArgsUpdateMemDeltaGens func(*chunk.Chunk, []*types.FieldType, []*util.ByItems) (memDeltas []int64, err error)
type aggMemTest struct {
aggTest aggTest
allocMemDelta int64
updateMemDeltaGens updateMemDeltaGens
isDistinct bool
}
func buildAggMemTester(funcName string, tp byte, numRows int, allocMemDelta int64, updateMemDeltaGens updateMemDeltaGens, isDistinct bool) aggMemTest {
aggTest := buildAggTester(funcName, tp, numRows)
pt := aggMemTest{
aggTest: aggTest,
allocMemDelta: allocMemDelta,
updateMemDeltaGens: updateMemDeltaGens,
isDistinct: isDistinct,
}
return pt
}
type multiArgsAggMemTest struct {
multiArgsAggTest multiArgsAggTest
allocMemDelta int64
multiArgsUpdateMemDeltaGens multiArgsUpdateMemDeltaGens
isDistinct bool
}
func buildMultiArgsAggMemTester(funcName string, tps []byte, rt byte, numRows int, allocMemDelta int64, updateMemDeltaGens multiArgsUpdateMemDeltaGens, isDistinct bool) multiArgsAggMemTest {
multiArgsAggTest := buildMultiArgsAggTester(funcName, tps, rt, numRows)
pt := multiArgsAggMemTest{
multiArgsAggTest: multiArgsAggTest,
allocMemDelta: allocMemDelta,
multiArgsUpdateMemDeltaGens: updateMemDeltaGens,
isDistinct: isDistinct,
}
return pt
}
func testMergePartialResult(t *testing.T, p aggTest) {
ctx := mock.NewContext()
srcChk := p.genSrcChk()
iter := chunk.NewIterator4Chunk(srcChk)
args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.GetCollate())
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false)
require.NoError(t, err)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
partialDesc, finalDesc := desc.Split([]int{0, 1})
// build partial func for partial phase.
partialFunc := aggfuncs.Build(ctx, partialDesc, 0)
partialResult, _ := partialFunc.AllocPartialResult()
// build final func for final phase.
finalFunc := aggfuncs.Build(ctx, finalDesc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, 1)
if p.funcName == ast.AggFuncApproxCountDistinct {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeString)}, 1)
}
if p.funcName == ast.AggFuncJsonArrayagg {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeJSON)}, 1)
}
// update partial result.
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
_, err = partialFunc.UpdatePartialResult(ctx, []chunk.Row{row}, partialResult)
require.NoError(t, err)
}
p.messUpChunk(srcChk)
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, p.dataType)
if p.funcName == ast.AggFuncApproxCountDistinct {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString))
}
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
require.NoError(t, err)
partialFunc.ResetPartialResult(partialResult)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
iter.Begin()
iter.Next()
for row := iter.Next(); row != iter.End(); row = iter.Next() {
_, err = partialFunc.UpdatePartialResult(ctx, []chunk.Row{row}, partialResult)
require.NoError(t, err)
}
p.messUpChunk(srcChk)
resultChk.Reset()
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
if p.funcName == ast.AggFuncApproxCountDistinct {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeString))
}
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
require.NoError(t, err)
if p.funcName == ast.AggFuncApproxCountDistinct {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1)
}
if p.funcName == ast.AggFuncJsonArrayagg {
resultChk = chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeJSON)}, 1)
}
resultChk.Reset()
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
if p.funcName == ast.AggFuncApproxCountDistinct {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeLonglong))
}
if p.funcName == ast.AggFuncJsonArrayagg {
dt = resultChk.GetRow(0).GetDatum(0, types.NewFieldType(mysql.TypeJSON))
}
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[2])
}
func buildAggTester(funcName string, tp byte, numRows int, results ...interface{}) aggTest {
return buildAggTesterWithFieldType(funcName, types.NewFieldType(tp), numRows, results...)
}
func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows int, results ...interface{}) aggTest {
pt := aggTest{
dataType: ft,
numRows: numRows,
funcName: funcName,
dataGen: getDataGenFunc(ft),
}
for _, result := range results {
pt.results = append(pt.results, types.NewDatum(result))
}
return pt
}
func testMultiArgsMergePartialResult(t *testing.T, ctx sessionctx.Context, p multiArgsAggTest) {
srcChk := p.genSrcChk()
iter := chunk.NewIterator4Chunk(srcChk)
args := make([]expression.Expression, len(p.dataTypes))
for k := 0; k < len(p.dataTypes); k++ {
args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k}
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false)
require.NoError(t, err)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
ctor := collate.GetCollator(args[0].GetType().GetCollate())
partialDesc, finalDesc := desc.Split([]int{0, 1})
// build partial func for partial phase.
partialFunc := aggfuncs.Build(ctx, partialDesc, 0)
partialResult, _ := partialFunc.AllocPartialResult()
// build final func for final phase.
finalFunc := aggfuncs.Build(ctx, finalDesc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.retType}, 1)
// update partial result.
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
// FIXME: cannot assert error since there are cases of error, e.g. JSON documents may not contain NULL member
_, _ = partialFunc.UpdatePartialResult(ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, p.retType)
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zero(t, result)
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
require.NoError(t, err)
partialFunc.ResetPartialResult(partialResult)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
iter.Begin()
iter.Next()
for row := iter.Next(); row != iter.End(); row = iter.Next() {
// FIXME: cannot check error
_, _ = partialFunc.UpdatePartialResult(ctx, []chunk.Row{row}, partialResult)
}
p.messUpChunk(srcChk)
resultChk.Reset()
err = partialFunc.AppendFinalResult2Chunk(ctx, partialResult, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zero(t, result)
_, err = finalFunc.MergePartialResult(ctx, partialResult, finalPr)
require.NoError(t, err)
resultChk.Reset()
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, p.retType)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[2], ctor)
require.NoError(t, err)
require.Zero(t, result)
}
// for multiple args in aggfuncs such as json_objectagg(c1, c2)
func buildMultiArgsAggTester(funcName string, tps []byte, rt byte, numRows int, results ...interface{}) multiArgsAggTest {
fts := make([]*types.FieldType, len(tps))
for i := 0; i < len(tps); i++ {
fts[i] = types.NewFieldType(tps[i])
}
return buildMultiArgsAggTesterWithFieldType(funcName, fts, types.NewFieldType(rt), numRows, results...)
}
func buildMultiArgsAggTesterWithFieldType(funcName string, fts []*types.FieldType, rt *types.FieldType, numRows int, results ...interface{}) multiArgsAggTest {
dataGens := make([]func(i int) types.Datum, len(fts))
for i := 0; i < len(fts); i++ {
dataGens[i] = getDataGenFunc(fts[i])
}
mt := multiArgsAggTest{
dataTypes: fts,
retType: rt,
numRows: numRows,
funcName: funcName,
dataGens: dataGens,
}
for _, result := range results {
mt.results = append(mt.results, types.NewDatum(result))
}
return mt
}
func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum {
switch ft.GetType() {
case mysql.TypeLonglong:
return func(i int) types.Datum { return types.NewIntDatum(int64(i)) }
case mysql.TypeFloat:
return func(i int) types.Datum { return types.NewFloat32Datum(float32(i)) }
case mysql.TypeNewDecimal:
return func(i int) types.Datum { return types.NewDecimalDatum(types.NewDecFromInt(int64(i))) }
case mysql.TypeDouble:
return func(i int) types.Datum { return types.NewFloat64Datum(float64(i)) }
case mysql.TypeString:
return func(i int) types.Datum { return types.NewStringDatum(fmt.Sprintf("%d", i)) }
case mysql.TypeDate:
return func(i int) types.Datum { return types.NewTimeDatum(types.TimeFromDays(int64(i + 365))) }
case mysql.TypeDuration:
return func(i int) types.Datum { return types.NewDurationDatum(types.Duration{Duration: time.Duration(i)}) }
case mysql.TypeJSON:
return func(i int) types.Datum { return types.NewDatum(types.CreateBinaryJSON(int64(i))) }
case mysql.TypeEnum:
elems := []string{"e", "d", "c", "b", "a"}
return func(i int) types.Datum {
e, _ := types.ParseEnumValue(elems, uint64(i+1))
return types.NewCollateMysqlEnumDatum(e, ft.GetCollate())
}
case mysql.TypeSet:
elems := []string{"e", "d", "c", "b", "a"}
return func(i int) types.Datum {
e, _ := types.ParseSetValue(elems, uint64(i+1))
return types.NewMysqlSetDatum(e, ft.GetCollate())
}
}
return nil
}
func testAggFunc(t *testing.T, p aggTest) {
srcChk := p.genSrcChk()
ctx := mock.NewContext()
args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.GetCollate())
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
if p.funcName == ast.AggFuncApproxPercentile {
args = append(args, &expression.Constant{Value: types.NewIntDatum(50), RetType: types.NewFieldType(mysql.TypeLong)})
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false)
require.NoError(t, err)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
iter := chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
_, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
require.NoError(t, err)
}
p.messUpChunk(srcChk)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])
// test the empty input
resultChk.Reset()
finalFunc.ResetPartialResult(finalPr)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])
// test the agg func with distinct
desc, err = aggregation.NewAggFuncDesc(ctx, p.funcName, args, true)
require.NoError(t, err)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(ctx, desc, 0)
finalPr, _ = finalFunc.AllocPartialResult()
resultChk.Reset()
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
_, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
require.NoError(t, err)
}
p.messUpChunk(srcChk)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
_, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
require.NoError(t, err)
}
p.messUpChunk(srcChk)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[1])
// test the empty input
resultChk.Reset()
finalFunc.ResetPartialResult(finalPr)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Equalf(t, 0, result, "%v != %v", dt.String(), p.results[0])
}
func testAggFuncWithoutDistinct(t *testing.T, p aggTest) {
srcChk := p.genSrcChk()
args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
ctor := collate.GetCollator(p.dataType.GetCollate())
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
if p.funcName == ast.AggFuncApproxPercentile {
args = append(args, &expression.Constant{Value: types.NewIntDatum(50), RetType: types.NewFieldType(mysql.TypeLong)})
}
ctx := mock.NewContext()
desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false)
require.NoError(t, err)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
iter := chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
_, err = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
require.NoError(t, err)
}
p.messUpChunk(srcChk)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])
// test the empty input
resultChk.Reset()
finalFunc.ResetPartialResult(finalPr)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[0])
}
func testAggMemFunc(t *testing.T, p aggMemTest) {
srcChk := p.aggTest.genSrcChk()
ctx := mock.NewContext()
args := []expression.Expression{&expression.Column{RetType: p.aggTest.dataType, Index: 0}}
if p.aggTest.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.aggTest.funcName, args, p.isDistinct)
require.NoError(t, err)
if p.aggTest.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, memDelta := finalFunc.AllocPartialResult()
require.Equal(t, p.allocMemDelta, memDelta)
updateMemDeltas, err := p.updateMemDeltaGens(srcChk, p.aggTest.dataType)
require.NoError(t, err)
iter := chunk.NewIterator4Chunk(srcChk)
i := 0
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
memDelta, err := finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
require.NoError(t, err)
require.Equal(t, updateMemDeltas[i], memDelta)
i++
}
}
func testMultiArgsAggFunc(t *testing.T, ctx sessionctx.Context, p multiArgsAggTest) {
srcChk := p.genSrcChk()
args := make([]expression.Expression, len(p.dataTypes))
for k := 0; k < len(p.dataTypes); k++ {
args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k}
}
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false)
require.NoError(t, err)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
ctor := collate.GetCollator(args[0].GetType().GetCollate())
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, _ := finalFunc.AllocPartialResult()
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
iter := chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
// FIXME: cannot assert error since there are cases of error, e.g. rows were cut by GROUPCONCAT
_, _ = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err := dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])
// test the empty input
resultChk.Reset()
finalFunc.ResetPartialResult(finalPr)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[0])
// test the agg func with distinct
desc, err = aggregation.NewAggFuncDesc(ctx, p.funcName, args, true)
require.NoError(t, err)
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(ctx, desc, 0)
finalPr, _ = finalFunc.AllocPartialResult()
resultChk.Reset()
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
// FIXME: cannot check error
_, _ = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
srcChk = p.genSrcChk()
iter = chunk.NewIterator4Chunk(srcChk)
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
// FIXME: cannot check error
_, _ = finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
}
p.messUpChunk(srcChk)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[1], ctor)
require.NoError(t, err)
require.Zerof(t, result, "%v != %v", dt.String(), p.results[1])
// test the empty input
resultChk.Reset()
finalFunc.ResetPartialResult(finalPr)
err = finalFunc.AppendFinalResult2Chunk(ctx, finalPr, resultChk)
require.NoError(t, err)
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
result, err = dt.Compare(ctx.GetSessionVars().StmtCtx, &p.results[0], ctor)
require.NoError(t, err)
require.Zero(t, result)
}
func testMultiArgsAggMemFunc(t *testing.T, p multiArgsAggMemTest) {
srcChk := p.multiArgsAggTest.genSrcChk()
ctx := mock.NewContext()
args := make([]expression.Expression, len(p.multiArgsAggTest.dataTypes))
for k := 0; k < len(p.multiArgsAggTest.dataTypes); k++ {
args[k] = &expression.Column{RetType: p.multiArgsAggTest.dataTypes[k], Index: k}
}
if p.multiArgsAggTest.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.multiArgsAggTest.funcName, args, p.isDistinct)
require.NoError(t, err)
if p.multiArgsAggTest.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(ctx, desc, 0)
finalPr, memDelta := finalFunc.AllocPartialResult()
require.Equal(t, p.allocMemDelta, memDelta)
updateMemDeltas, err := p.multiArgsUpdateMemDeltaGens(srcChk, p.multiArgsAggTest.dataTypes, desc.OrderByItems)
require.NoError(t, err)
iter := chunk.NewIterator4Chunk(srcChk)
i := 0
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
memDelta, _ := finalFunc.UpdatePartialResult(ctx, []chunk.Row{row}, finalPr)
require.Equal(t, updateMemDeltas[i], memDelta)
i++
}
}
func benchmarkAggFunc(b *testing.B, ctx sessionctx.Context, p aggTest) {
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
for i := 0; i < p.numRows; i++ {
dt := p.dataGen(i)
srcChk.AppendDatum(0, &dt)
}
srcChk.AppendDatum(0, &types.Datum{})
args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false)
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(ctx, desc, 0)
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
iter := chunk.NewIterator4Chunk(srcChk)
input := make([]chunk.Row, 0, iter.Len())
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
input = append(input, row)
}
b.Run(fmt.Sprintf("%v/%v", p.funcName, p.dataType), func(b *testing.B) {
baseBenchmarkAggFunc(b, ctx, finalFunc, input, resultChk)
})
desc, err = aggregation.NewAggFuncDesc(ctx, p.funcName, args, true)
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(ctx, desc, 0)
resultChk.Reset()
b.Run(fmt.Sprintf("%v(distinct)/%v", p.funcName, p.dataType), func(b *testing.B) {
baseBenchmarkAggFunc(b, ctx, finalFunc, input, resultChk)
})
}
func benchmarkMultiArgsAggFunc(b *testing.B, ctx sessionctx.Context, p multiArgsAggTest) {
srcChk := chunk.NewChunkWithCapacity(p.dataTypes, p.numRows)
for i := 0; i < p.numRows; i++ {
for j := 0; j < len(p.dataGens); j++ {
fdt := p.dataGens[j](i)
srcChk.AppendDatum(j, &fdt)
}
}
srcChk.AppendDatum(0, &types.Datum{})
args := make([]expression.Expression, len(p.dataTypes))
for k := 0; k < len(p.dataTypes); k++ {
args[k] = &expression.Column{RetType: p.dataTypes[k], Index: k}
}
if p.funcName == ast.AggFuncGroupConcat {
args = append(args, &expression.Constant{Value: types.NewStringDatum(separator), RetType: types.NewFieldType(mysql.TypeString)})
}
desc, err := aggregation.NewAggFuncDesc(ctx, p.funcName, args, false)
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc := aggfuncs.Build(ctx, desc, 0)
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
iter := chunk.NewIterator4Chunk(srcChk)
input := make([]chunk.Row, 0, iter.Len())
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
input = append(input, row)
}
b.Run(fmt.Sprintf("%v/%v", p.funcName, p.dataTypes), func(b *testing.B) {
baseBenchmarkAggFunc(b, ctx, finalFunc, input, resultChk)
})
desc, err = aggregation.NewAggFuncDesc(ctx, p.funcName, args, true)
if err != nil {
b.Fatal(err)
}
if p.orderBy {
desc.OrderByItems = []*util.ByItems{
{Expr: args[0], Desc: true},
}
}
finalFunc = aggfuncs.Build(ctx, desc, 0)
resultChk.Reset()
b.Run(fmt.Sprintf("%v(distinct)/%v", p.funcName, p.dataTypes), func(b *testing.B) {
baseBenchmarkAggFunc(b, ctx, finalFunc, input, resultChk)
})
}
func baseBenchmarkAggFunc(b *testing.B, ctx sessionctx.Context, finalFunc aggfuncs.AggFunc, input []chunk.Row, output *chunk.Chunk) {
finalPr, _ := finalFunc.AllocPartialResult()
output.Reset()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := finalFunc.UpdatePartialResult(ctx, input, finalPr)
if err != nil {
b.Fatal(err)
}
b.StopTimer()
output.Reset()
b.StartTimer()
}
}