From b6b9b8a6df800015bca44c5034c2eef11f5ba419 Mon Sep 17 00:00:00 2001 From: mmyj Date: Wed, 9 Sep 2020 14:31:41 +0800 Subject: [PATCH] exector, planner: Improve the performance of the aggFuncMaxMin by using sliding window (#16819) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * temp * temp * unit test * executor: Improve the performance of `aggFuncMin` by using sliding window * executor: Improve the performance of `aggFuncMin` by using sliding window * executor: Improve the performance of `aggFuncMin` by using sliding window add ut * fix: ResetPartialResult * add maxMin4Time.Slide and maxMin4Duration.Slide * add benchmark * add var `dirty` * add a comment * fix Decimal * newMyDecimalMaxMinQueue * newMyDecimalMaxMinQueue * newMyDecimalMaxMinQueue * newMyDecimalMaxMinQueue * newMyDecimalMaxMinQueue * implementing the maxMinQueue using heap * fix import check * fix import check * fix check * remove PushMyDecimal * refactor maxMinHeap * 尝试优化heap * 尝试优化heap * fix benchmark * fix:* * solved pr comments * solved pr comments * fix * fix * fix * 先回家 * maxMin4IntSliding * fix import * fix * new builder * fix ut * fix * fix * fix ut * fix ut * fix fmt * add benchmark * fix * fix * lazyload * lazyload * fix frame * fix check_dev * add a unit test * sliding aggFunc * sliding aggFunc * fix * move ut * fix dev_check * fix dev_check * resolved comments * refactor ut * refactor ut * refactor ut Co-authored-by: Yuanjia Zhang --- executor/aggfuncs/aggfunc_test.go | 22 +- executor/aggfuncs/builder.go | 30 ++ executor/aggfuncs/func_max_min.go | 612 +++++++++++++++++++++++++ executor/aggfuncs/func_max_min_test.go | 101 ++++ executor/benchmark_test.go | 23 + executor/window_test.go | 48 ++ planner/core/logical_plan_builder.go | 12 +- 7 files changed, 845 insertions(+), 3 deletions(-) diff --git a/executor/aggfuncs/aggfunc_test.go b/executor/aggfuncs/aggfunc_test.go index 46d9b7a7a5..37b3b60d92 100644 --- a/executor/aggfuncs/aggfunc_test.go +++ 
b/executor/aggfuncs/aggfunc_test.go @@ -26,11 +26,16 @@ import ( "github.com/pingcap/parser" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" + "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/executor/aggfuncs" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/aggregation" + "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/planner/util" + "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/store/mockstore" + "github.com/pingcap/tidb/store/mockstore/cluster" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" @@ -50,13 +55,28 @@ func TestT(t *testing.T) { type testSuite struct { *parser.Parser - ctx sessionctx.Context + ctx sessionctx.Context + cluster cluster.Cluster + store kv.Storage + domain *domain.Domain } func (s *testSuite) SetUpSuite(c *C) { s.Parser = parser.New() s.ctx = mock.NewContext() s.ctx.GetSessionVars().StmtCtx.TimeZone = time.Local + store, err := mockstore.NewMockStore( + mockstore.WithClusterInspector(func(c cluster.Cluster) { + mockstore.BootstrapWithSingleStore(c) + s.cluster = c + }), + ) + c.Assert(err, IsNil) + s.store = store + d, err := session.BootstrapSession(s.store) + c.Assert(err, IsNil) + d.SetStatsUpdating(true) + s.domain = d } func (s *testSuite) TearDownSuite(c *C) { diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index e994b7399b..c4c00de085 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -88,6 +88,11 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag return buildLead(windowFuncDesc, ordinal) case ast.WindowFuncLag: return buildLag(windowFuncDesc, ordinal) + case ast.AggFuncMax: + // The max/min aggFunc using in the window function will using the sliding window algo. 
+ return buildMaxMinInWindowFunction(windowFuncDesc, ordinal, true) + case ast.AggFuncMin: + return buildMaxMinInWindowFunction(windowFuncDesc, ordinal, false) default: return Build(ctx, windowFuncDesc, ordinal) } @@ -361,6 +366,31 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) return nil } +// buildMaxMin builds the AggFunc implementation for function "MAX" and "MIN" using by window function. +func buildMaxMinInWindowFunction(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) AggFunc { + base := buildMaxMin(aggFuncDesc, ordinal, isMax) + // build max/min aggFunc for window function using sliding window + switch baseAggFunc := base.(type) { + case *maxMin4Int: + return &maxMin4IntSliding{*baseAggFunc} + case *maxMin4Uint: + return &maxMin4UintSliding{*baseAggFunc} + case *maxMin4Float32: + return &maxMin4Float32Sliding{*baseAggFunc} + case *maxMin4Float64: + return &maxMin4Float64Sliding{*baseAggFunc} + case *maxMin4Decimal: + return &maxMin4DecimalSliding{*baseAggFunc} + case *maxMin4String: + return &maxMin4StringSliding{*baseAggFunc} + case *maxMin4Time: + return &maxMin4TimeSliding{*baseAggFunc} + case *maxMin4Duration: + return &maxMin4DurationSliding{*baseAggFunc} + } + return base +} + // buildGroupConcat builds the AggFunc implementation for function "GROUP_CONCAT". 
func buildGroupConcat(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { switch aggFuncDesc.Mode { diff --git a/executor/aggfuncs/func_max_min.go b/executor/aggfuncs/func_max_min.go index 708e27fc12..528c68e498 100644 --- a/executor/aggfuncs/func_max_min.go +++ b/executor/aggfuncs/func_max_min.go @@ -14,6 +14,8 @@ package aggfuncs import ( + "container/heap" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" @@ -21,47 +23,164 @@ import ( "github.com/pingcap/tidb/util/stringutil" ) +type maxMinHeap struct { + data []interface{} + h heap.Interface + varSet map[interface{}]int64 + isMax bool + cmpFunc func(i, j interface{}) int +} + +func newMaxMinHeap(isMax bool, cmpFunc func(i, j interface{}) int) *maxMinHeap { + h := &maxMinHeap{ + data: make([]interface{}, 0), + varSet: make(map[interface{}]int64), + isMax: isMax, + cmpFunc: cmpFunc, + } + return h +} + +func (h *maxMinHeap) Len() int { return len(h.data) } +func (h *maxMinHeap) Less(i, j int) bool { + if h.isMax { + return h.cmpFunc(h.data[i], h.data[j]) > 0 + } + return h.cmpFunc(h.data[i], h.data[j]) < 0 +} +func (h *maxMinHeap) Swap(i, j int) { h.data[i], h.data[j] = h.data[j], h.data[i] } + +func (h *maxMinHeap) Push(x interface{}) { + h.data = append(h.data, x) +} +func (h *maxMinHeap) Pop() interface{} { + old := h.data + n := len(old) + x := old[n-1] + h.data = old[0 : n-1] + return x +} + +func (h *maxMinHeap) Reset() { + h.data = h.data[:0] + h.varSet = make(map[interface{}]int64) +} +func (h *maxMinHeap) Append(val interface{}) { + h.varSet[val]++ + if h.varSet[val] == 1 { + heap.Push(h, val) + } +} +func (h *maxMinHeap) Remove(val interface{}) { + if h.varSet[val] > 0 { + h.varSet[val]-- + } else { + panic("remove a not exist value") + } +} +func (h *maxMinHeap) Top() (val interface{}, isEmpty bool) { +retry: + if h.Len() == 0 { + return nil, true + } + top := h.data[0] + if h.varSet[top] == 0 { + _ = heap.Pop(h) 
+ goto retry + } + return top, false +} + +func (h *maxMinHeap) AppendMyDecimal(val types.MyDecimal) error { + key, err := val.ToHashKey() + if err != nil { + return err + } + h.varSet[string(key)]++ + if h.varSet[string(key)] == 1 { + heap.Push(h, val) + } + return nil +} +func (h *maxMinHeap) RemoveMyDecimal(val types.MyDecimal) error { + key, err := val.ToHashKey() + if err != nil { + return err + } + if h.varSet[string(key)] > 0 { + h.varSet[string(key)]-- + } else { + panic("remove a not exist value") + } + return nil +} +func (h *maxMinHeap) TopDecimal() (val types.MyDecimal, isEmpty bool) { +retry: + if h.Len() == 0 { + return types.MyDecimal{}, true + } + top := h.data[0].(types.MyDecimal) + key, err := top.ToHashKey() + if err != nil { + panic(err) + } + if h.varSet[string(key)] == 0 { + _ = heap.Pop(h) + goto retry + } + return top, false +} + type partialResult4MaxMinInt struct { val int64 // isNull is used to indicates: // 1. whether the partial result is the initialization value which should not be compared during evaluation; // 2. whether all the values of arg are all null, if so, we should return null as the default value for MAX/MIN. isNull bool + // maxMinHeap is an ordered queue, using to evaluate the maximum or minimum value in a sliding window. 
+ heap *maxMinHeap } type partialResult4MaxMinUint struct { val uint64 isNull bool + heap *maxMinHeap } type partialResult4MaxMinDecimal struct { val types.MyDecimal isNull bool + heap *maxMinHeap } type partialResult4MaxMinFloat32 struct { val float32 isNull bool + heap *maxMinHeap } type partialResult4MaxMinFloat64 struct { val float64 isNull bool + heap *maxMinHeap } type partialResult4Time struct { val types.Time isNull bool + heap *maxMinHeap } type partialResult4MaxMinDuration struct { val types.Duration isNull bool + heap *maxMinHeap } type partialResult4MaxMinString struct { val string isNull bool + heap *maxMinHeap } type partialResult4MaxMinJSON struct { @@ -92,6 +211,9 @@ type maxMin4Int struct { func (e *maxMin4Int) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4MaxMinInt) p.isNull = true + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + return types.CompareInt64(i.(int64), j.(int64)) + }) return PartialResult(p), 0 } @@ -99,6 +221,7 @@ func (e *maxMin4Int) ResetPartialResult(pr PartialResult) { p := (*partialResult4MaxMinInt)(pr) p.val = 0 p.isNull = true + p.heap.Reset() } func (e *maxMin4Int) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -148,6 +271,62 @@ func (e *maxMin4Int) MergePartialResult(sctx sessionctx.Context, src, dst Partia return 0, nil } +type maxMin4IntSliding struct { + maxMin4Int +} + +func (e *maxMin4IntSliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinInt)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + p.heap.Append(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(int64) + p.isNull = false + } else { + p.isNull = true + } + return 0, nil +} + +func (e *maxMin4IntSliding) Slide(sctx 
sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4MaxMinInt)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalInt(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Append(input) + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalInt(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Remove(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(int64) + p.isNull = false + } else { + p.isNull = true + } + return nil +} + type maxMin4Uint struct { baseMaxMinAggFunc } @@ -155,6 +334,9 @@ type maxMin4Uint struct { func (e *maxMin4Uint) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4MaxMinUint) p.isNull = true + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + return types.CompareUint64(i.(uint64), j.(uint64)) + }) return PartialResult(p), 0 } @@ -162,6 +344,7 @@ func (e *maxMin4Uint) ResetPartialResult(pr PartialResult) { p := (*partialResult4MaxMinUint)(pr) p.val = 0 p.isNull = true + p.heap.Reset() } func (e *maxMin4Uint) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -212,6 +395,62 @@ func (e *maxMin4Uint) MergePartialResult(sctx sessionctx.Context, src, dst Parti return 0, nil } +type maxMin4UintSliding struct { + maxMin4Uint +} + +func (e *maxMin4UintSliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinUint)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalInt(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + p.heap.Append(uint64(input)) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(uint64) + p.isNull = false + } else { + 
p.isNull = true + } + return 0, nil +} + +func (e *maxMin4UintSliding) Slide(sctx sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4MaxMinUint)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalInt(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Append(uint64(input)) + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalInt(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Remove(uint64(input)) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(uint64) + p.isNull = false + } else { + p.isNull = true + } + return nil +} + // maxMin4Float32 gets a float32 input and returns a float32 result. type maxMin4Float32 struct { baseMaxMinAggFunc @@ -220,6 +459,9 @@ type maxMin4Float32 struct { func (e *maxMin4Float32) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4MaxMinFloat32) p.isNull = true + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + return types.CompareFloat64(float64(i.(float32)), float64(j.(float32))) + }) return PartialResult(p), 0 } @@ -227,6 +469,7 @@ func (e *maxMin4Float32) ResetPartialResult(pr PartialResult) { p := (*partialResult4MaxMinFloat32)(pr) p.val = 0 p.isNull = true + p.heap.Reset() } func (e *maxMin4Float32) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -277,6 +520,62 @@ func (e *maxMin4Float32) MergePartialResult(sctx sessionctx.Context, src, dst Pa return 0, nil } +type maxMin4Float32Sliding struct { + maxMin4Float32 +} + +func (e *maxMin4Float32Sliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinFloat32)(pr) + for _, row := range rowsInGroup { + input, isNull, err := 
e.args[0].EvalReal(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + p.heap.Append(float32(input)) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(float32) + p.isNull = false + } else { + p.isNull = true + } + return 0, nil +} + +func (e *maxMin4Float32Sliding) Slide(sctx sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4MaxMinFloat32)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalReal(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Append(float32(input)) + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalReal(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Remove(float32(input)) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(float32) + p.isNull = false + } else { + p.isNull = true + } + return nil +} + type maxMin4Float64 struct { baseMaxMinAggFunc } @@ -284,6 +583,9 @@ type maxMin4Float64 struct { func (e *maxMin4Float64) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4MaxMinFloat64) p.isNull = true + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + return types.CompareFloat64(i.(float64), j.(float64)) + }) return PartialResult(p), 0 } @@ -291,6 +593,7 @@ func (e *maxMin4Float64) ResetPartialResult(pr PartialResult) { p := (*partialResult4MaxMinFloat64)(pr) p.val = 0 p.isNull = true + p.heap.Reset() } func (e *maxMin4Float64) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -340,6 +643,62 @@ func (e *maxMin4Float64) MergePartialResult(sctx sessionctx.Context, src, dst Pa return 0, nil } +type maxMin4Float64Sliding struct { + maxMin4Float64 +} + +func (e *maxMin4Float64Sliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr 
PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinFloat64)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalReal(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + p.heap.Append(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(float64) + p.isNull = false + } else { + p.isNull = true + } + return 0, nil +} + +func (e *maxMin4Float64Sliding) Slide(sctx sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4MaxMinFloat64)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalReal(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Append(input) + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalReal(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Remove(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(float64) + p.isNull = false + } else { + p.isNull = true + } + return nil +} + type maxMin4Decimal struct { baseMaxMinAggFunc } @@ -347,12 +706,18 @@ type maxMin4Decimal struct { func (e *maxMin4Decimal) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4MaxMinDecimal) p.isNull = true + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + src := i.(types.MyDecimal) + dst := j.(types.MyDecimal) + return src.Compare(&dst) + }) return PartialResult(p), 0 } func (e *maxMin4Decimal) ResetPartialResult(pr PartialResult) { p := (*partialResult4MaxMinDecimal)(pr) p.isNull = true + p.heap.Reset() } func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -404,6 +769,68 @@ func (e *maxMin4Decimal) MergePartialResult(sctx sessionctx.Context, src, dst Pa return 0, nil } +type maxMin4DecimalSliding struct { + 
maxMin4Decimal +} + +func (e *maxMin4DecimalSliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinDecimal)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalDecimal(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + if err := p.heap.AppendMyDecimal(*input); err != nil { + return 0, err + } + } + if val, isEmpty := p.heap.TopDecimal(); !isEmpty { + p.val = val + p.isNull = false + } else { + p.isNull = true + } + return 0, nil +} + +func (e *maxMin4DecimalSliding) Slide(sctx sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4MaxMinDecimal)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalDecimal(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + if err := p.heap.AppendMyDecimal(*input); err != nil { + return err + } + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalDecimal(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + if err := p.heap.RemoveMyDecimal(*input); err != nil { + return err + } + } + if val, isEmpty := p.heap.TopDecimal(); !isEmpty { + p.val = val + p.isNull = false + } else { + p.isNull = true + } + return nil +} + type maxMin4String struct { baseMaxMinAggFunc retTp *types.FieldType @@ -412,12 +839,17 @@ type maxMin4String struct { func (e *maxMin4String) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4MaxMinString) p.isNull = true + tp := e.args[0].GetType() + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + return types.CompareString(i.(string), j.(string), tp.Collate) + }) return PartialResult(p), 0 } func (e *maxMin4String) ResetPartialResult(pr PartialResult) { p := (*partialResult4MaxMinString)(pr) 
p.isNull = true + p.heap.Reset() } func (e *maxMin4String) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -475,6 +907,62 @@ func (e *maxMin4String) MergePartialResult(sctx sessionctx.Context, src, dst Par return 0, nil } +type maxMin4StringSliding struct { + maxMin4String +} + +func (e *maxMin4StringSliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinString)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalString(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + p.heap.Append(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(string) + p.isNull = false + } else { + p.isNull = true + } + return 0, nil +} + +func (e *maxMin4StringSliding) Slide(sctx sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4MaxMinString)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalString(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Append(input) + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalString(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Remove(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(string) + p.isNull = false + } else { + p.isNull = true + } + return nil +} + type maxMin4Time struct { baseMaxMinAggFunc } @@ -482,12 +970,18 @@ type maxMin4Time struct { func (e *maxMin4Time) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4Time) p.isNull = true + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + src := i.(types.Time) + dst := j.(types.Time) + return src.Compare(dst) + }) return PartialResult(p), 0 } func (e 
*maxMin4Time) ResetPartialResult(pr PartialResult) { p := (*partialResult4Time)(pr) p.isNull = true + p.heap.Reset() } func (e *maxMin4Time) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -539,6 +1033,62 @@ func (e *maxMin4Time) MergePartialResult(sctx sessionctx.Context, src, dst Parti return 0, nil } +type maxMin4TimeSliding struct { + maxMin4Time +} + +func (e *maxMin4TimeSliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4Time)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalTime(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + p.heap.Append(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(types.Time) + p.isNull = false + } else { + p.isNull = true + } + return 0, nil +} + +func (e *maxMin4TimeSliding) Slide(sctx sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4Time)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalTime(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Append(input) + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalTime(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Remove(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(types.Time) + p.isNull = false + } else { + p.isNull = true + } + return nil +} + type maxMin4Duration struct { baseMaxMinAggFunc } @@ -546,12 +1096,18 @@ type maxMin4Duration struct { func (e *maxMin4Duration) AllocPartialResult() (pr PartialResult, memDelta int64) { p := new(partialResult4MaxMinDuration) p.isNull = true + p.heap = newMaxMinHeap(e.isMax, func(i, j interface{}) int { + src := i.(types.Duration) + dst := 
j.(types.Duration) + return src.Compare(dst) + }) return PartialResult(p), 0 } func (e *maxMin4Duration) ResetPartialResult(pr PartialResult) { p := (*partialResult4MaxMinDuration)(pr) p.isNull = true + p.heap.Reset() } func (e *maxMin4Duration) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error { @@ -603,6 +1159,62 @@ func (e *maxMin4Duration) MergePartialResult(sctx sessionctx.Context, src, dst P return 0, nil } +type maxMin4DurationSliding struct { + maxMin4Duration +} + +func (e *maxMin4DurationSliding) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) (memDelta int64, err error) { + p := (*partialResult4MaxMinDuration)(pr) + for _, row := range rowsInGroup { + input, isNull, err := e.args[0].EvalDuration(sctx, row) + if err != nil { + return 0, err + } + if isNull { + continue + } + p.heap.Append(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(types.Duration) + p.isNull = false + } else { + p.isNull = true + } + return 0, nil +} + +func (e *maxMin4DurationSliding) Slide(sctx sessionctx.Context, rows []chunk.Row, lastStart, lastEnd uint64, shiftStart, shiftEnd uint64, pr PartialResult) error { + p := (*partialResult4MaxMinDuration)(pr) + for i := uint64(0); i < shiftEnd; i++ { + input, isNull, err := e.args[0].EvalDuration(sctx, rows[lastEnd+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Append(input) + } + for i := uint64(0); i < shiftStart; i++ { + input, isNull, err := e.args[0].EvalDuration(sctx, rows[lastStart+i]) + if err != nil { + return err + } + if isNull { + continue + } + p.heap.Remove(input) + } + if val, isEmpty := p.heap.Top(); !isEmpty { + p.val = val.(types.Duration) + p.isNull = false + } else { + p.isNull = true + } + return nil +} + type maxMin4JSON struct { baseMaxMinAggFunc } diff --git a/executor/aggfuncs/func_max_min_test.go b/executor/aggfuncs/func_max_min_test.go index 29e9792778..b7188fe285 100644 --- 
a/executor/aggfuncs/func_max_min_test.go +++ b/executor/aggfuncs/func_max_min_test.go @@ -14,6 +14,7 @@ package aggfuncs_test import ( + "fmt" "time" . "github.com/pingcap/check" @@ -21,6 +22,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" + "github.com/pingcap/tidb/util/testkit" ) func (s *testSuite) TestMergePartialResult4MaxMin(c *C) { @@ -93,3 +95,102 @@ func (s *testSuite) TestMaxMin(c *C) { s.testAggFunc(c, test) } } + +type maxSlidingWindowTestCase struct { + rowType string + insertValue string + expect []string + orderByExpect []string + orderBy bool + frameType ast.FrameType +} + +func testMaxSlidingWindow(tk *testkit.TestKit, tc maxSlidingWindowTestCase) { + tk.MustExec(fmt.Sprintf("CREATE TABLE t (a %s);", tc.rowType)) + tk.MustExec(fmt.Sprintf("insert into t values %s;", tc.insertValue)) + var orderBy string + if tc.orderBy { + orderBy = "ORDER BY a" + } + var result *testkit.Result + switch tc.frameType { + case ast.Rows: + result = tk.MustQuery(fmt.Sprintf("SELECT max(a) OVER (%s ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t;", orderBy)) + case ast.Ranges: + result = tk.MustQuery(fmt.Sprintf("SELECT max(a) OVER (%s RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM t;", orderBy)) + default: + result = tk.MustQuery(fmt.Sprintf("SELECT max(a) OVER (%s) FROM t;", orderBy)) + if tc.orderBy { + result.Check(testkit.Rows(tc.orderByExpect...)) + return + } + } + result.Check(testkit.Rows(tc.expect...)) +} + +func (s *testSuite) TestMaxSlidingWindow(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + testCases := []maxSlidingWindowTestCase{ + { + rowType: "bigint", + insertValue: "(1), (3), (2)", + expect: []string{"3", "3", "3"}, + orderByExpect: []string{"1", "2", "3"}, + }, + { + rowType: "float", + insertValue: "(1.1), (3.3), (2.2)", + expect: []string{"3.3", "3.3", "3.3"}, + orderByExpect: []string{"1.1", "2.2", "3.3"}, + }, + { + rowType: "double", 
+ insertValue: "(1.1), (3.3), (2.2)", + expect: []string{"3.3", "3.3", "3.3"}, + orderByExpect: []string{"1.1", "2.2", "3.3"}, + }, + { + rowType: "decimal(5, 2)", + insertValue: "(1.1), (3.3), (2.2)", + expect: []string{"3.30", "3.30", "3.30"}, + orderByExpect: []string{"1.10", "2.20", "3.30"}, + }, + { + rowType: "text", + insertValue: "('1.1'), ('3.3'), ('2.2')", + expect: []string{"3.3", "3.3", "3.3"}, + orderByExpect: []string{"1.1", "2.2", "3.3"}, + }, + { + rowType: "time", + insertValue: "('00:00:00'), ('03:00:00'), ('02:00:00')", + expect: []string{"03:00:00", "03:00:00", "03:00:00"}, + orderByExpect: []string{"00:00:00", "02:00:00", "03:00:00"}, + }, + { + rowType: "date", + insertValue: "('2020-09-08'), ('2022-09-10'), ('2020-09-10')", + expect: []string{"2022-09-10", "2022-09-10", "2022-09-10"}, + orderByExpect: []string{"2020-09-08", "2020-09-10", "2022-09-10"}, + }, + { + rowType: "datetime", + insertValue: "('2020-09-08 02:00:00'), ('2022-09-10 00:00:00'), ('2020-09-10 00:00:00')", + expect: []string{"2022-09-10 00:00:00", "2022-09-10 00:00:00", "2022-09-10 00:00:00"}, + orderByExpect: []string{"2020-09-08 02:00:00", "2020-09-10 00:00:00", "2022-09-10 00:00:00"}, + }, + } + + orderBy := []bool{false, true} + frameType := []ast.FrameType{ast.Rows, ast.Ranges, -1} + for _, o := range orderBy { + for _, f := range frameType { + for _, tc := range testCases { + tc.frameType = f + tc.orderBy = o + tk.MustExec("drop table if exists t;") + testMaxSlidingWindow(tk, tc) + } + } + } +} diff --git a/executor/benchmark_test.go b/executor/benchmark_test.go index 5c65634f45..8afe822ce8 100644 --- a/executor/benchmark_test.go +++ b/executor/benchmark_test.go @@ -449,6 +449,8 @@ func buildWindowExecutor(ctx sessionctx.Context, windowFunc string, funcs int, f args = append(args, src.Schema().Columns[0]) case ast.AggFuncBitXor: args = append(args, src.Schema().Columns[0]) + case ast.AggFuncMax, ast.AggFuncMin: + args = append(args, src.Schema().Columns[0]) default: 
args = append(args, partitionBy[0]) } @@ -672,6 +674,23 @@ func BenchmarkWindowFunctionsWithFrame(b *testing.B) { } } +func BenchmarkWindowFunctionsAggWindowProcessorAboutFrame(b *testing.B) { + b.ReportAllocs() + windowFunc := ast.AggFuncMax + frame := &core.WindowFrame{Type: ast.Rows, Start: &core.FrameBound{UnBounded: true}, End: &core.FrameBound{UnBounded: true}} + cas := defaultWindowTestCase() + cas.rows = 10000 + cas.ndv = 10 + cas.concurrency = 1 + cas.dataSourceSorted = false + cas.windowFunc = windowFunc + cas.numFunc = 1 + cas.frame = frame + b.Run(fmt.Sprintf("%v", cas), func(b *testing.B) { + benchmarkWindowExecWithCase(b, cas) + }) +} + func baseBenchmarkWindowFunctionsWithSlidingWindow(b *testing.B, frameType ast.FrameType) { b.ReportAllocs() windowFuncs := []struct { @@ -684,6 +703,10 @@ func baseBenchmarkWindowFunctionsWithSlidingWindow(b *testing.B, frameType ast.F {ast.AggFuncAvg, mysql.TypeFloat}, {ast.AggFuncAvg, mysql.TypeNewDecimal}, {ast.AggFuncBitXor, mysql.TypeLong}, + {ast.AggFuncMax, mysql.TypeLong}, + {ast.AggFuncMax, mysql.TypeFloat}, + {ast.AggFuncMin, mysql.TypeLong}, + {ast.AggFuncMin, mysql.TypeFloat}, } row := 100000 ndv := 100 diff --git a/executor/window_test.go b/executor/window_test.go index a3a31ca315..219113cf38 100644 --- a/executor/window_test.go +++ b/executor/window_test.go @@ -353,4 +353,52 @@ func baseTestSlidingWindowFunctions(tk *testkit.TestKit) { result.Check(testkit.Rows("M 0", "F 4", "F 0", "F 2", "M 1", " 1", " 1")) result = tk.MustQuery("SELECT sex, BIT_XOR(id) OVER (ORDER BY id DESC RANGE BETWEEN 1 PRECEDING and 2 FOLLOWING) FROM t;") result.Check(testkit.Rows(" 1", " 1", "M 2", "F 0", "F 4", "F 0", "M 3")) + + // MIN ROWS + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id ROWS BETWEEN 1 FOLLOWING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 2", "F 3", "F 4", "F 5", "M 10", " 11", " ")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id ROWS BETWEEN 3 FOLLOWING and 1 
FOLLOWING) FROM t;") + result.Check(testkit.Rows("M ", "F ", "F ", "F ", "M ", " ", " ")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id ROWS BETWEEN 2 PRECEDING and 1 PRECEDING) FROM t;") + result.Check(testkit.Rows("M ", "F 1", "F 1", "F 2", "M 3", " 4", " 5")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING and 3 PRECEDING) FROM t;") + result.Check(testkit.Rows("M ", "F ", "F ", "F ", "M ", " ", " ")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id ROWS BETWEEN UNBOUNDED PRECEDING and 1 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 1", "F 1", "F 1", "F 1", "M 1", " 1", " 1")) + + // MIN RANGE + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id RANGE BETWEEN 1 FOLLOWING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 2", "F 3", "F 4", "F 5", "M ", " 11", " ")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id RANGE BETWEEN 3 FOLLOWING and 1 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M ", "F ", "F ", "F ", "M ", " ", " ")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id RANGE BETWEEN 2 PRECEDING and 1 PRECEDING) FROM t;") + result.Check(testkit.Rows("M ", "F 1", "F 1", "F 2", "M 3", " ", " 10")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id RANGE BETWEEN 1 PRECEDING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 1", "F 1", "F 2", "F 3", "M 4", " 10", " 10")) + result = tk.MustQuery("SELECT sex, MIN(id) OVER (ORDER BY id DESC RANGE BETWEEN 1 PRECEDING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows(" 10", " 10", "M 3", "F 2", "F 1", "F 1", "M 1")) + + // MAX ROWS + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id ROWS BETWEEN 1 FOLLOWING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 3", "F 4", "F 5", "F 10", "M 11", " 11", " ")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id ROWS BETWEEN 3 FOLLOWING and 1 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M ", 
"F ", "F ", "F ", "M ", " ", " ")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id ROWS BETWEEN 2 PRECEDING and 1 PRECEDING) FROM t;") + result.Check(testkit.Rows("M ", "F 1", "F 2", "F 3", "M 4", " 5", " 10")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id ROWS BETWEEN 1 PRECEDING and 3 PRECEDING) FROM t;") + result.Check(testkit.Rows("M ", "F ", "F ", "F ", "M ", " ", " ")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id ROWS BETWEEN UNBOUNDED PRECEDING and 1 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 2", "F 3", "F 4", "F 5", "M 10", " 11", " 11")) + + // MAX RANGE + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id RANGE BETWEEN 1 FOLLOWING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 3", "F 4", "F 5", "F 5", "M ", " 11", " ")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id RANGE BETWEEN 3 FOLLOWING and 1 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M ", "F ", "F ", "F ", "M ", " ", " ")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id RANGE BETWEEN 2 PRECEDING and 1 PRECEDING) FROM t;") + result.Check(testkit.Rows("M ", "F 1", "F 2", "F 3", "M 4", " ", " 10")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id RANGE BETWEEN 1 PRECEDING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows("M 3", "F 4", "F 5", "F 5", "M 5", " 11", " 11")) + result = tk.MustQuery("SELECT sex, MAX(id) OVER (ORDER BY id DESC RANGE BETWEEN 1 PRECEDING and 2 FOLLOWING) FROM t;") + result.Check(testkit.Rows(" 11", " 11", "M 5", "F 5", "F 4", "F 3", "M 2")) } diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 0142d4395d..beabe64679 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -4457,8 +4457,9 @@ func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, windowFuncName st needFrame := aggregation.NeedFrame(windowFuncName) // According to MySQL, In the absence of a 
frame clause, the default frame depends on whether an ORDER BY clause is present: // (1) With order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"; - // (2) Without order by, the default frame is equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", - // which is the same as an empty frame. + // (2) Without order by, the default frame includes all partition rows, equivalent to "RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", + // or "ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING", which is the same as an empty frame. + // https://dev.mysql.com/doc/refman/8.0/en/window-functions-frames.html if needFrame && spec.Frame == nil && spec.OrderBy != nil { newSpec := *spec newSpec.Frame = &ast.FrameClause{ @@ -4470,6 +4471,13 @@ func (b *PlanBuilder) handleDefaultFrame(spec *ast.WindowSpec, windowFuncName st } return &newSpec, true } + // "RANGE/ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING" is equivalent to empty frame. + if needFrame && spec.Frame != nil && + spec.Frame.Extent.Start.UnBounded && spec.Frame.Extent.End.UnBounded { + newSpec := *spec + newSpec.Frame = nil + return &newSpec, true + } // For functions that operate on the entire partition, the frame clause will be ignored. if !needFrame && spec.Frame != nil { specName := spec.Name.O