diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index c71333afd5..fa53876710 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ b/executor/aggfuncs/aggfunc_test.go @@ -14,12 +14,21 @@ package aggfuncs_test import ( + "fmt" "testing" "time" . "github.com/pingcap/check" "github.com/pingcap/parser" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/executor/aggfuncs" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/mock" ) @@ -52,3 +61,107 @@ func (s *testSuite) SetUpTest(c *C) { func (s *testSuite) TearDownTest(c *C) { s.ctx.GetSessionVars().StmtCtx.SetWarnings(nil) } + +type aggMergeTest struct { + dataType *types.FieldType + numRows int + dataGen func(i int) types.Datum + funcName string + results []types.Datum +} + +func (s *testSuite) testMergePartialResult(c *C, p aggMergeTest) { + srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, p.numRows) + for i := 0; i < p.numRows; i++ { + dt := p.dataGen(i) + srcChk.AppendDatum(0, &dt) + } + iter := chunk.NewIterator4Chunk(srcChk) + + args := []expression.Expression{&expression.Column{RetType: p.dataType, Index: 0}} + if p.funcName == ast.AggFuncGroupConcat { + args = append(args, &expression.Constant{Value: types.NewStringDatum(" "), RetType: types.NewFieldType(mysql.TypeString)}) + } + desc := aggregation.NewAggFuncDesc(s.ctx, p.funcName, args, false) + partialDesc, finalDesc := desc.Split([]int{0, 1}) + + // build partial func for partial phase. + partialFunc := aggfuncs.Build(s.ctx, partialDesc, 0) + partialResult := partialFunc.AllocPartialResult() + + // build final func for final phase. + finalFunc := aggfuncs.Build(s.ctx, finalDesc, 0) + finalPr := finalFunc.AllocPartialResult() + resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{p.dataType}, 1) + + // update partial result. + for row := iter.Begin(); row != iter.End(); row = iter.Next() { + partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) + } + partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) + dt := resultChk.GetRow(0).GetDatum(0, p.dataType) + result, err := dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[0]) + c.Assert(err, IsNil) + c.Assert(result, Equals, 0) + + err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr) + c.Assert(err, IsNil) + partialFunc.ResetPartialResult(partialResult) + + iter.Begin() + iter.Next() + for row := iter.Next(); row != iter.End(); row = iter.Next() { + partialFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialResult) + } + resultChk.Reset() + partialFunc.AppendFinalResult2Chunk(s.ctx, partialResult, resultChk) + dt = resultChk.GetRow(0).GetDatum(0, p.dataType) + result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[1]) + c.Assert(err, IsNil) + c.Assert(result, Equals, 0) + err = finalFunc.MergePartialResult(s.ctx, partialResult, finalPr) + c.Assert(err, IsNil) + + resultChk.Reset() + err = finalFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) + c.Assert(err, IsNil) + + dt = resultChk.GetRow(0).GetDatum(0, p.dataType) + result, err = dt.CompareDatum(s.ctx.GetSessionVars().StmtCtx, &p.results[2]) + c.Assert(err, IsNil) + c.Assert(result, Equals, 0) +} + +func buildAggMergeTester(funcName string, tp byte, numRows int, results ...interface{}) aggMergeTest { + return buildAggMergeTesterWithFieldType(funcName, types.NewFieldType(tp), numRows, results...) +} + +func buildAggMergeTesterWithFieldType(funcName string, ft *types.FieldType, numRows int, results ...interface{}) aggMergeTest { + pt := aggMergeTest{ + dataType: ft, + numRows: numRows, + funcName: funcName, + } + for _, result := range results { + pt.results = append(pt.results, types.NewDatum(result)) + } + switch ft.Tp { + case mysql.TypeLonglong: + pt.dataGen = func(i int) types.Datum { return types.NewIntDatum(int64(i)) } + case mysql.TypeFloat: + pt.dataGen = func(i int) types.Datum { return types.NewFloat32Datum(float32(i)) } + case mysql.TypeNewDecimal: + pt.dataGen = func(i int) types.Datum { return types.NewDecimalDatum(types.NewDecFromInt(int64(i))) } + case mysql.TypeDouble: + pt.dataGen = func(i int) types.Datum { return types.NewFloat64Datum(float64(i)) } + case mysql.TypeString: + pt.dataGen = func(i int) types.Datum { return types.NewStringDatum(fmt.Sprintf("%d", i)) } + case mysql.TypeDate: + pt.dataGen = func(i int) types.Datum { return types.NewTimeDatum(types.TimeFromDays(int64(i))) } + case mysql.TypeDuration: + pt.dataGen = func(i int) types.Datum { return types.NewDurationDatum(types.Duration{Duration: time.Duration(i)}) } + case mysql.TypeJSON: + pt.dataGen = func(i int) types.Datum { return types.NewDatum(json.CreateBinary(int64(i))) } + } + return pt +} diff --git a/executor/aggfuncs/func_avg_test.go b/executor/aggfuncs/func_avg_test.go index 6a52a5cc34..17cb6ffeaf 100644 --- a/executor/aggfuncs/func_avg_test.go +++ b/executor/aggfuncs/func_avg_test.go @@ -17,112 +17,14 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/executor/aggfuncs" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/aggregation" - "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/chunk" ) -func (s *testSuite) TestMergePartialResult4AvgDecimal(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendMyDecimal(0, types.NewDecFromInt(i)) +func (s *testSuite) TestMergePartialResult4Avg(c *C) { + tests := []aggMergeTest{ + buildAggMergeTester(ast.AggFuncAvg, mysql.TypeNewDecimal, 5, 2.0, 3.0, 2.375), + buildAggMergeTester(ast.AggFuncAvg, mysql.TypeDouble, 5, 2.0, 3.0, 2.375), } - iter := chunk.NewIterator4Chunk(srcChk) - - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncAvg, args, false) - partialDesc, finalDesc := desc.Split([]int{0, 1}) - - // build avg func for partial phase. - partialAvgFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialAvgFunc.AllocPartialResult() - partialPr2 := partialAvgFunc.AllocPartialResult() - - // build final func for final phase. - finalAvgFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalAvgFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialAvgFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) + for _, test := range tests { + s.testMergePartialResult(c, test) } - // (0+1+2+3+4) / 5 - partialAvgFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(2)) == 0, IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialAvgFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - // (2+3+4) / 3 - partialAvgFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromFloatForTest(3)) == 0, IsTrue) - - // merge two partial results. - err := finalAvgFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalAvgFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalAvgFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - // (10 + 9) / 8 - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromFloatForTest(2.375)) == 0, IsTrue) -} - -func (s *testSuite) TestMergePartialResult4AvgFloat(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendFloat64(0, float64(i)) - } - iter := chunk.NewIterator4Chunk(srcChk) - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncAvg, args, false) - partialDesc, finalDesc := desc.Split([]int{0, 1}) - - // build avg func for partial phase. - partialAvgFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialAvgFunc.AllocPartialResult() - partialPr2 := partialAvgFunc.AllocPartialResult() - - // build final func for final phase. - finalAvgFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalAvgFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialAvgFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) - } - partialAvgFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - // (0+1+2+3+4) / 5 - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(2), IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialAvgFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - partialAvgFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - // (2+3+4) / 3 - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(3), IsTrue) - - // merge two partial results. - err := finalAvgFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalAvgFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalAvgFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - // (10 + 9) / 8 - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(2.375), IsTrue) } diff --git a/executor/aggfuncs/func_bitfuncs_test.go b/executor/aggfuncs/func_bitfuncs_test.go new file mode 100644 index 0000000000..c213b65f2e --- /dev/null +++ b/executor/aggfuncs/func_bitfuncs_test.go @@ -0,0 +1,31 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" +) + +func (s *testSuite) TestMergePartialResult4BitFuncs(c *C) { + tests := []aggMergeTest{ + buildAggMergeTester(ast.AggFuncBitAnd, mysql.TypeLonglong, 5, 0, 0, 0), + buildAggMergeTester(ast.AggFuncBitOr, mysql.TypeLonglong, 5, 7, 7, 7), + buildAggMergeTester(ast.AggFuncBitXor, mysql.TypeLonglong, 5, 4, 5, 1), + } + for _, test := range tests { + s.testMergePartialResult(c, test) + } +} diff --git a/executor/aggfuncs/func_count_test.go b/executor/aggfuncs/func_count_test.go index d1eb85ac61..2685153eec 100644 --- a/executor/aggfuncs/func_count_test.go +++ b/executor/aggfuncs/func_count_test.go @@ -17,50 +17,9 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/executor/aggfuncs" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/aggregation" - "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/chunk" ) func (s *testSuite) TestMergePartialResult4Count(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendInt64(0, i) - } - iter := chunk.NewIterator4Chunk(srcChk) - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeLong), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncCount, args, false) - partialDesc, finalDesc := desc.Split([]int{0}) - - // build count func for partial phase. - partialCountFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialCountFunc.AllocPartialResult() - - // build final func for final phase. - finalCountFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalCountFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeLonglong)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialCountFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) - } - partialCountFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - c.Assert(resultChk.GetRow(0).GetInt64(0), Equals, int64(5)) - - // suppose there are two partial workers. - partialPr2 := partialPr1 - - // merge two partial results. - err := finalCountFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalCountFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalCountFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - c.Assert(resultChk.GetRow(0).GetInt64(0), Equals, int64(10)) + tester := buildAggMergeTester(ast.AggFuncCount, mysql.TypeLonglong, 5, 5, 3, 8) + s.testMergePartialResult(c, tester) } diff --git a/executor/aggfuncs/func_first_row_test.go b/executor/aggfuncs/func_first_row_test.go new file mode 100644 index 0000000000..3f3e8b23de --- /dev/null +++ b/executor/aggfuncs/func_first_row_test.go @@ -0,0 +1,40 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs_test + +import ( + "time" + + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/types/json" +) + +func (s *testSuite) TestMergePartialResult4FirstRow(c *C) { + tests := []aggMergeTest{ + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeLonglong, 5, 0, 2, 0), + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeFloat, 5, 0.0, 2.0, 0.0), + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeDouble, 5, 0.0, 2.0, 0.0), + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeNewDecimal, 5, types.NewDecFromInt(0), types.NewDecFromInt(2), types.NewDecFromInt(0)), + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeString, 5, "0", "2", "0"), + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeDate, 5, types.TimeFromDays(0), types.TimeFromDays(2), types.TimeFromDays(0)), + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeDuration, 5, types.Duration{Duration: time.Duration(0)}, types.Duration{Duration: time.Duration(2)}, types.Duration{Duration: time.Duration(0)}), + buildAggMergeTester(ast.AggFuncFirstRow, mysql.TypeJSON, 5, json.CreateBinary(int64(0)), json.CreateBinary(int64(2)), json.CreateBinary(int64(0))), + } + for _, test := range tests { + s.testMergePartialResult(c, test) + } +} diff --git a/executor/aggfuncs/func_group_concat_test.go b/executor/aggfuncs/func_group_concat_test.go new file mode 100644 index 0000000000..404120e3f6 --- /dev/null +++ b/executor/aggfuncs/func_group_concat_test.go @@ -0,0 +1,25 @@ +// Copyright 2019 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package aggfuncs_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/parser/ast" + "github.com/pingcap/parser/mysql" +) + +func (s *testSuite) TestMergePartialResult4GroupConcat(c *C) { + test := buildAggMergeTester(ast.AggFuncGroupConcat, mysql.TypeString, 5, "0 1 2 3 4", "2 3 4", "0 1 2 3 4 2 3 4") + s.testMergePartialResult(c, test) +} diff --git a/executor/aggfuncs/func_max_min_test.go b/executor/aggfuncs/func_max_min_test.go index dc407797d3..2f4601dd83 100644 --- a/executor/aggfuncs/func_max_min_test.go +++ b/executor/aggfuncs/func_max_min_test.go @@ -14,211 +14,40 @@ package aggfuncs_test import ( + "time" + . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/executor/aggfuncs" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/chunk" + "github.com/pingcap/tidb/types/json" ) -func (s *testSuite) TestMergePartialResult4MaxDecimal(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendMyDecimal(0, types.NewDecFromInt(i)) +func (s *testSuite) TestMergePartialResult4MaxMin(c *C) { + unsignedType := types.NewFieldType(mysql.TypeLonglong) + unsignedType.Flag |= mysql.UnsignedFlag + tests := []aggMergeTest{ + buildAggMergeTester(ast.AggFuncMax, mysql.TypeLonglong, 5, 4, 4, 4), + buildAggMergeTesterWithFieldType(ast.AggFuncMax, unsignedType, 5, 4, 4, 4), + buildAggMergeTester(ast.AggFuncMax, mysql.TypeFloat, 5, 4.0, 4.0, 4.0), + buildAggMergeTester(ast.AggFuncMax, mysql.TypeDouble, 5, 4.0, 4.0, 4.0), + buildAggMergeTester(ast.AggFuncMax, mysql.TypeNewDecimal, 5, types.NewDecFromInt(4), types.NewDecFromInt(4), types.NewDecFromInt(4)), + buildAggMergeTester(ast.AggFuncMax, mysql.TypeString, 5, "4", "4", "4"), + buildAggMergeTester(ast.AggFuncMax, mysql.TypeDate, 5, types.TimeFromDays(4), types.TimeFromDays(4), types.TimeFromDays(4)), + buildAggMergeTester(ast.AggFuncMax, mysql.TypeDuration, 5, types.Duration{Duration: time.Duration(4)}, types.Duration{Duration: time.Duration(4)}, types.Duration{Duration: time.Duration(4)}), + buildAggMergeTester(ast.AggFuncMax, mysql.TypeJSON, 5, json.CreateBinary(int64(4)), json.CreateBinary(int64(4)), json.CreateBinary(int64(4))), + + buildAggMergeTester(ast.AggFuncMin, mysql.TypeLonglong, 5, 0, 2, 0), + buildAggMergeTesterWithFieldType(ast.AggFuncMin, unsignedType, 5, 0, 2, 0), + buildAggMergeTester(ast.AggFuncMin, mysql.TypeFloat, 5, 0.0, 2.0, 0.0), + buildAggMergeTester(ast.AggFuncMin, mysql.TypeDouble, 5, 0.0, 2.0, 0.0), + buildAggMergeTester(ast.AggFuncMin, mysql.TypeNewDecimal, 5, types.NewDecFromInt(0), types.NewDecFromInt(2), types.NewDecFromInt(0)), + buildAggMergeTester(ast.AggFuncMin, mysql.TypeString, 5, "0", "2", "0"), + buildAggMergeTester(ast.AggFuncMin, mysql.TypeDate, 5, types.TimeFromDays(0), types.TimeFromDays(2), types.TimeFromDays(0)), + buildAggMergeTester(ast.AggFuncMin, mysql.TypeDuration, 5, types.Duration{Duration: time.Duration(0)}, types.Duration{Duration: time.Duration(2)}, types.Duration{Duration: time.Duration(0)}), + buildAggMergeTester(ast.AggFuncMin, mysql.TypeJSON, 5, json.CreateBinary(int64(0)), json.CreateBinary(int64(2)), json.CreateBinary(int64(0))), } - iter := chunk.NewIterator4Chunk(srcChk) - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeNewDecimal), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncMax, args, false) - partialDesc, finalDesc := desc.Split([]int{0}) - - // build max func for partial phase. - partialMaxFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialMaxFunc.AllocPartialResult() - partialPr2 := partialMaxFunc.AllocPartialResult() - - // build final func for final phase. - finalMaxFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalMaxFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialMaxFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) + for _, test := range tests { + s.testMergePartialResult(c, test) } - partialMaxFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(4)) == 0, IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialMaxFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - partialMaxFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(4)) == 0, IsTrue) - - // merge two partial results. - err := finalMaxFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalMaxFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalMaxFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(4)) == 0, IsTrue) -} - -func (s *testSuite) TestMergePartialResult4MaxFloat(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendFloat64(0, float64(i)) - } - iter := chunk.NewIterator4Chunk(srcChk) - - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncMax, args, false) - partialDesc, finalDesc := desc.Split([]int{0}) - - // build max func for partial phase. - partialMaxFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialMaxFunc.AllocPartialResult() - partialPr2 := partialMaxFunc.AllocPartialResult() - - // build final func for final phase. - finalMaxFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalMaxFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialMaxFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) - } - partialMaxFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(4), IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialMaxFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - partialMaxFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(4), IsTrue) - - // merge two partial results. - err := finalMaxFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalMaxFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalMaxFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(4), IsTrue) -} - -func (s *testSuite) TestMergePartialResult4MinDecimal(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendMyDecimal(0, types.NewDecFromInt(i)) - } - iter := chunk.NewIterator4Chunk(srcChk) - - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeNewDecimal), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncMin, args, false) - partialDesc, finalDesc := desc.Split([]int{0}) - - // build min func for partial phase. - partialMinFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialMinFunc.AllocPartialResult() - partialPr2 := partialMinFunc.AllocPartialResult() - - // build final func for final phase. - finalMinFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalMinFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialMinFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) - } - partialMinFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(0)) == 0, IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialMinFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - partialMinFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - // Min in [2,3,4] -> 2 - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(2)) == 0, IsTrue) - - // merge two partial results. - err := finalMinFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalMinFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalMinFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - // Min in [0,1,2,3,4] -> 0 - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(0)) == 0, IsTrue) -} - -func (s *testSuite) TestMergePartialResult4MinFloat(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendFloat64(0, float64(i)) - } - iter := chunk.NewIterator4Chunk(srcChk) - - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncMin, args, false) - partialDesc, finalDesc := desc.Split([]int{0}) - - // build min func for partial phase. - partialMinFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialMinFunc.AllocPartialResult() - partialPr2 := partialMinFunc.AllocPartialResult() - - // build final func for final phase. - finalMinFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalMinFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialMinFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) - } - partialMinFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(0), IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialMinFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - partialMinFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - // Min in [2.0,3.0,4.0] -> 0 - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(2), IsTrue) - - // merge two partial results. - err := finalMinFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalMinFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalMinFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - // Min in [0.0,1.0,2.0,3.0,4.0] -> 0 - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(0), IsTrue) } diff --git a/executor/aggfuncs/func_sum_test.go b/executor/aggfuncs/func_sum_test.go index 4f6fdff5f7..463e9570c7 100644 --- a/executor/aggfuncs/func_sum_test.go +++ b/executor/aggfuncs/func_sum_test.go @@ -17,113 +17,15 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" - "github.com/pingcap/tidb/executor/aggfuncs" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/expression/aggregation" "github.com/pingcap/tidb/types" - "github.com/pingcap/tidb/util/chunk" ) -func (s *testSuite) TestMergePartialResult4SumDecimal(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendMyDecimal(0, types.NewDecFromInt(i)) +func (s *testSuite) TestMergePartialResult4Sum(c *C) { + tests := []aggMergeTest{ + buildAggMergeTester(ast.AggFuncSum, mysql.TypeNewDecimal, 5, types.NewDecFromInt(10), types.NewDecFromInt(9), types.NewDecFromInt(19)), + buildAggMergeTester(ast.AggFuncSum, mysql.TypeDouble, 5, 10.0, 9.0, 19.0), } - iter := chunk.NewIterator4Chunk(srcChk) - - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeLonglong), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncSum, args, false) - partialDesc, finalDesc := desc.Split([]int{0}) - - // build sum func for partial phase. - partialSumFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialSumFunc.AllocPartialResult() - partialPr2 := partialSumFunc.AllocPartialResult() - - // build final func for final phase. - finalAvgFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalAvgFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeNewDecimal)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialSumFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) + for _, test := range tests { + s.testMergePartialResult(c, test) } - // 0+1+2+3+4 - partialSumFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(10)) == 0, IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialSumFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - // 2+3+4 - partialSumFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(9)) == 0, IsTrue) - - // merge two partial results. - err := finalAvgFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalAvgFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalAvgFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - // 10+9 - c.Assert(resultChk.GetRow(0).GetMyDecimal(0).Compare(types.NewDecFromInt(19)) == 0, IsTrue) -} - -func (s *testSuite) TestMergePartialResult4SumFloat(c *C) { - srcChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 5) - for i := int64(0); i < 5; i++ { - srcChk.AppendFloat64(0, float64(i)) - } - iter := chunk.NewIterator4Chunk(srcChk) - - args := []expression.Expression{&expression.Column{RetType: types.NewFieldType(mysql.TypeDouble), Index: 0}} - desc := aggregation.NewAggFuncDesc(s.ctx, ast.AggFuncSum, args, false) - partialDesc, finalDesc := desc.Split([]int{0}) - - // build sum func for partial phase. - partialSumFunc := aggfuncs.Build(s.ctx, partialDesc, 0) - partialPr1 := partialSumFunc.AllocPartialResult() - partialPr2 := partialSumFunc.AllocPartialResult() - - // build final func for final phase. - finalAvgFunc := aggfuncs.Build(s.ctx, finalDesc, 0) - finalPr := finalAvgFunc.AllocPartialResult() - resultChk := chunk.NewChunkWithCapacity([]*types.FieldType{types.NewFieldType(mysql.TypeDouble)}, 1) - - // update partial result. - for row := iter.Begin(); row != iter.End(); row = iter.Next() { - partialSumFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr1) - } - partialSumFunc.AppendFinalResult2Chunk(s.ctx, partialPr1, resultChk) - // (0+1+2+3+4) - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(10), IsTrue) - - iter.Begin() - iter.Next() - for row := iter.Next(); row != iter.End(); row = iter.Next() { - partialSumFunc.UpdatePartialResult(s.ctx, []chunk.Row{row}, partialPr2) - } - resultChk.Reset() - partialSumFunc.AppendFinalResult2Chunk(s.ctx, partialPr2, resultChk) - // (2+3+4) - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(9), IsTrue) - - // merge two partial results. - err := finalAvgFunc.MergePartialResult(s.ctx, partialPr1, finalPr) - c.Assert(err, IsNil) - err = finalAvgFunc.MergePartialResult(s.ctx, partialPr2, finalPr) - c.Assert(err, IsNil) - - resultChk.Reset() - err = finalAvgFunc.AppendFinalResult2Chunk(s.ctx, finalPr, resultChk) - c.Assert(err, IsNil) - // (10 + 9) - c.Assert(resultChk.GetRow(0).GetFloat64(0) == float64(19), IsTrue) }