238 lines
8.0 KiB
Go
238 lines
8.0 KiB
Go
// Copyright 2018 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package aggfuncs_test
|
|
|
|
import (
|
|
"fmt"
|
|
"testing"
|
|
"time"
|
|
|
|
. "github.com/pingcap/check"
|
|
"github.com/pingcap/parser"
|
|
"github.com/pingcap/parser/ast"
|
|
"github.com/pingcap/parser/mysql"
|
|
"github.com/pingcap/tidb/executor/aggfuncs"
|
|
"github.com/pingcap/tidb/expression"
|
|
"github.com/pingcap/tidb/expression/aggregation"
|
|
"github.com/pingcap/tidb/sessionctx"
|
|
"github.com/pingcap/tidb/types"
|
|
"github.com/pingcap/tidb/types/json"
|
|
"github.com/pingcap/tidb/util/chunk"
|
|
"github.com/pingcap/tidb/util/mock"
|
|
)
|
|
|
|
// Hand a testSuite instance to Suite from the dot-imported check package.
// NOTE(review): presumably this registers the suite so that TestingT (called
// from TestT) runs its methods — confirm against the check package docs.
var _ = Suite(&testSuite{})
|
|
|
|
func TestT(t *testing.T) {
|
|
CustomVerboseFlag = true
|
|
*CustomParallelSuiteFlag = true
|
|
TestingT(t)
|
|
}
|
|
|
|
type testSuite struct {
|
|
*parser.Parser
|
|
ctx sessionctx.Context
|
|
}
|
|
|
|
func (s *testSuite) SetUpSuite(c *C) {
|
|
s.Parser = parser.New()
|
|
s.ctx = mock.NewContext()
|
|
s.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local
|
|
}
|
|
|
|
func (s *testSuite) TearDownSuite(c *C) {
|
|
}
|
|
|
|
func (s *testSuite) SetUpTest(c *C) {
|
|
s.ctx.GetSessionVars().PlanColumnID = 0
|
|
}
|
|
|
|
func (s *testSuite) TearDownTest(c *C) {
|
|
s.ctx.GetSessionVars().StmtCtx.SetWarnings(nil)
|
|
}
|
|
|
|
type aggTest struct {
|
|
dataType *types.FieldType
|
|
numRows int
|
|
dataGen func(i int) types.Datum
|
|
funcName string
|
|
results []types.Datum
|
|
}
|
|
|
|
func (s *testSuite) testMergePartialResult(c *C, p aggTest) {
|
|
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
|
|
for i := 0; i < p.numRows; i++ {
|
|
dt := p.dataGen(i)
|
|
srcChk.AppendDatum(0, &dt)
|
|
}
|
|
iter := chunk.NewIterator4Chunk(srcChk)
|
|
|
|
args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
|
|
if p.funcName == ast.AggFuncGroupConcat {
|
|
args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)})
|
|
}
|
|
desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
|
|
partialDesc, finalDesc := desc.Split([]int{0, 1})
|
|
|
|
// build partial func for partial phase.
|
|
partialFunc := aggfuncs.Build(s.ctx, partialDesc, 0)
|
|
partialResult := partialFunc.AllocPartialResult()
|
|
|
|
// build final func for final phase.
|
|
finalFunc := aggfuncs.Build(s.ctx, finalDesc, 0)
|
|
finalPr := finalFunc.AllocPartialResult()
|
|
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, 1)
|
|
|
|
// update partial result.
|
|
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
|
|
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
|
|
}
|
|
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
|
|
dt := resultChk.GetRow(0).GetDatum(0, p.dataType)
|
|
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
|
|
c.Assert(err, IsNil)
|
|
c.Assert(result, Equals, 0)
|
|
|
|
err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
|
|
c.Assert(err, IsNil)
|
|
partialFunc.ResetPartialResult(partialResult)
|
|
|
|
iter.Begin()
|
|
iter.Next()
|
|
for row := iter.Next(); row != iter.End(); row = iter.Next() {
|
|
partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult)
|
|
}
|
|
resultChk.Reset()
|
|
partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk)
|
|
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
|
|
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
|
|
c.Assert(err, IsNil)
|
|
c.Assert(result, Equals, 0)
|
|
err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr)
|
|
c.Assert(err, IsNil)
|
|
|
|
resultChk.Reset()
|
|
err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
|
|
c.Assert(err, IsNil)
|
|
|
|
dt = resultChk.GetRow(0).GetDatum(0, p.dataType)
|
|
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2])
|
|
c.Assert(err, IsNil)
|
|
c.Assert(result, Equals, 0)
|
|
}
|
|
|
|
func buildAggTester(funcName string, tp byte, numRows int, results ...interface{}) aggTest {
|
|
return buildAggTesterWithFieldType(funcName, types.NewFieldType(tp), numRows, results...)
|
|
}
|
|
|
|
func buildAggTesterWithFieldType(funcName string, ft *types.FieldType, numRows int, results ...interface{}) aggTest {
|
|
pt := aggTest{
|
|
dataType: ft,
|
|
numRows: numRows,
|
|
funcName: funcName,
|
|
dataGen: getDataGenFunc(ft),
|
|
}
|
|
for _, result := range results {
|
|
pt.results = append(pt.results, types.NewDatum(result))
|
|
}
|
|
return pt
|
|
}
|
|
|
|
func getDataGenFunc(ft *types.FieldType) func(i int) types.Datum {
|
|
switch ft.Tp {
|
|
case mysql.TypeLonglong:
|
|
return func(i int) types.Datum { return types.NewIntDatum(int64(i)) }
|
|
case mysql.TypeFloat:
|
|
return func(i int) types.Datum { return types.NewFloat32Datum(float32(i)) }
|
|
case mysql.TypeNewDecimal:
|
|
return func(i int) types.Datum { return types.NewDecimalDatum(types.NewDecFromInt(int64(i))) }
|
|
case mysql.TypeDouble:
|
|
return func(i int) types.Datum { return types.NewFloat64Datum(float64(i)) }
|
|
case mysql.TypeString:
|
|
return func(i int) types.Datum { return types.NewStringDatum(fmt.Sprintf("%d", i)) }
|
|
case mysql.TypeDate:
|
|
return func(i int) types.Datum { return types.NewTimeDatum(types.TimeFromDays(int64(i + 365))) }
|
|
case mysql.TypeDuration:
|
|
return func(i int) types.Datum { return types.NewDurationDatum(types.Duration{Duration: time.Duration(i)}) }
|
|
case mysql.TypeJSON:
|
|
return func(i int) types.Datum { return types.NewDatum(json.CreateBinary(int64(i))) }
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *testSuite) testAggFunc(c *C, p aggTest) {
|
|
srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows)
|
|
for i := 0; i < p.numRows; i++ {
|
|
dt := p.dataGen(i)
|
|
srcChk.AppendDatum(0, &dt)
|
|
}
|
|
srcChk.AppendDatum(0, &types.Datum{})
|
|
|
|
args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}}
|
|
if p.funcName == ast.AggFuncGroupConcat {
|
|
args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)})
|
|
}
|
|
desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false)
|
|
finalFunc := aggfuncs.Build(s.ctx, desc, 0)
|
|
finalPr := finalFunc.AllocPartialResult()
|
|
resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{desc.RetTp}, 1)
|
|
|
|
iter := chunk.NewIterator4Chunk(srcChk)
|
|
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
|
|
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
|
|
}
|
|
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
|
|
dt := resultChk.GetRow(0).GetDatum(0, desc.RetTp)
|
|
result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
|
|
c.Assert(err, IsNil)
|
|
c.Assert(result, Equals, 0)
|
|
|
|
// test the empty input
|
|
resultChk.Reset()
|
|
finalFunc.ResetPartialResult(finalPr)
|
|
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
|
|
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
|
|
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
|
|
c.Assert(err, IsNil)
|
|
c.Assert(result, Equals, 0)
|
|
|
|
// test the agg func with distinct
|
|
desc = aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, true)
|
|
finalFunc = aggfuncs.Build(s.ctx, desc, 0)
|
|
finalPr = finalFunc.AllocPartialResult()
|
|
|
|
resultChk.Reset()
|
|
iter = chunk.NewIterator4Chunk(srcChk)
|
|
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
|
|
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
|
|
}
|
|
for row := iter.Begin(); row != iter.End(); row = iter.Next() {
|
|
finalFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, finalPr)
|
|
}
|
|
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
|
|
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
|
|
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1])
|
|
c.Assert(err, IsNil)
|
|
c.Assert(result, Equals, 0)
|
|
|
|
// test the empty input
|
|
resultChk.Reset()
|
|
finalFunc.ResetPartialResult(finalPr)
|
|
finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk)
|
|
dt = resultChk.GetRow(0).GetDatum(0, desc.RetTp)
|
|
result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0])
|
|
c.Assert(err, IsNil)
|
|
c.Assert(result, Equals, 0)
|
|
}
|