executor: support window function cume_dist (#9619)

This commit is contained in:
Haibin Xie
2019-03-11 19:30:54 +08:00
committed by GitHub
parent ec208f5484
commit f0aca29fff
5 changed files with 112 additions and 18 deletions

View File

@ -69,6 +69,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
return buildFirstValue(windowFuncDesc, ordinal)
case ast.WindowFuncLastValue:
return buildLastValue(windowFuncDesc, ordinal)
case ast.WindowFuncCumeDist:
return buildCumeDist(ordinal, orderByCols)
default:
return Build(ctx, windowFuncDesc, ordinal)
}
@ -345,11 +347,7 @@ func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggF
base := baseAggFunc{
ordinal: ordinal,
}
r := &rank{baseAggFunc: base, isDense: isDense}
for _, col := range orderByCols {
r.cmpFuncs = append(r.cmpFuncs, chunk.GetCompareFunc(col.RetType))
r.colIdx = append(r.colIdx, col.Index)
}
r := &rank{baseAggFunc: base, isDense: isDense, rowComparer: buildRowComparer(orderByCols)}
return r
}
@ -368,3 +366,11 @@ func buildLastValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
}
return &lastValue{baseAggFunc: base, tp: aggFuncDesc.RetTp}
}
// buildCumeDist builds the AggFunc used for ast.WindowFuncCumeDist.
// ordinal is stored in the embedded baseAggFunc, and orderByCols are turned
// into a rowComparer used to compare rows.
func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc {
	return &cumeDist{
		baseAggFunc: baseAggFunc{ordinal: ordinal},
		rowComparer: buildRowComparer(orderByCols),
	}
}

View File

@ -0,0 +1,58 @@
// Copyright 2019 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package aggfuncs
import (
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)
// cumeDist is the AggFunc built by buildCumeDist for ast.WindowFuncCumeDist.
// It embeds baseAggFunc, which carries the output ordinal, and rowComparer,
// which compares rows on the columns handed to buildCumeDist.
type cumeDist struct {
baseAggFunc
rowComparer
}
// partialResult4CumeDist is the state that cumeDist.UpdatePartialResult and
// cumeDist.AppendFinalResult2Chunk read and write through a PartialResult.
type partialResult4CumeDist struct {
// curIdx is the index in rows of the row compared against on the next
// AppendFinalResult2Chunk call; it is incremented once per call.
curIdx int
// lastRank is advanced past every row that compares equal to rows[curIdx];
// it is the numerator of the emitted lastRank/len(rows) ratio.
lastRank int
// rows accumulates every row passed to UpdatePartialResult.
rows []chunk.Row
}
// AllocPartialResult allocates the state used by cumeDist.
//
// It allocates a partialResult4CumeDist because UpdatePartialResult and
// AppendFinalResult2Chunk cast the PartialResult to that type. Allocating a
// different struct only works if both structs have exactly the same memory
// layout.
func (r *cumeDist) AllocPartialResult() PartialResult {
	return PartialResult(&partialResult4CumeDist{})
}

// ResetPartialResult clears the state so it can be reused: both counters are
// zeroed and rows is truncated to length 0, keeping its backing array.
func (r *cumeDist) ResetPartialResult(pr PartialResult) {
	p := (*partialResult4CumeDist)(pr)
	p.curIdx = 0
	p.lastRank = 0
	p.rows = p.rows[:0]
}
// UpdatePartialResult appends every row of rowsInGroup to the rows buffered in
// the partial result. Nothing is computed here; the buffered rows are consumed
// by AppendFinalResult2Chunk. It always returns nil.
func (r *cumeDist) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
	state := (*partialResult4CumeDist)(pr)
	state.rows = append(state.rows, rowsInGroup...)
	return nil
}
// AppendFinalResult2Chunk appends one float64 value for the row at curIdx to
// column r.ordinal of chk, then moves curIdx to the next row.
//
// lastRank is first advanced past every buffered row that compares equal to
// the current row; the emitted value is lastRank divided by the number of
// buffered rows. It always returns nil.
func (r *cumeDist) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
	state := (*partialResult4CumeDist)(pr)
	total := len(state.rows)
	// The bound check must run before the comparison so that rows is never
	// indexed once lastRank has reached the end of the buffer.
	for state.lastRank < total {
		if r.compareRows(state.rows[state.curIdx], state.rows[state.lastRank]) != 0 {
			break
		}
		state.lastRank++
	}
	state.curIdx++
	chk.AppendFloat64(r.ordinal, float64(state.lastRank)/float64(total))
	return nil
}

View File

@ -14,15 +14,15 @@
package aggfuncs
import (
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/util/chunk"
)
// rank is the AggFunc built by buildRank.
//
// The row comparison state (compare functions and column indexes) lives in the
// embedded rowComparer, so it is not declared again here; cmpFuncs and colIdx
// remain reachable as promoted fields.
type rank struct {
	baseAggFunc
	// isDense is set from buildRank's isDense argument.
	// NOTE(review): presumably selects DENSE_RANK over RANK semantics — confirm
	// against AppendFinalResult2Chunk.
	isDense bool
	rowComparer
}
type partialResult4Rank struct {
@ -48,16 +48,6 @@ func (r *rank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.
return nil
}
// compareRows compares prev and curr on each column listed in r.colIdx, in
// order, using the matching entry of r.cmpFuncs. It returns the first non-zero
// comparison result, or 0 when every column compares equal.
func (r *rank) compareRows(prev, curr chunk.Row) int {
	for pos := 0; pos < len(r.colIdx); pos++ {
		col := r.colIdx[pos]
		if cmp := r.cmpFuncs[pos](prev, col, curr, col); cmp != 0 {
			return cmp
		}
	}
	return 0
}
func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
p := (*partialResult4Rank)(pr)
p.curIdx++
@ -78,3 +68,29 @@ func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult
chk.AppendInt64(r.ordinal, p.lastRank)
return nil
}
// rowComparer compares two rows column by column. cmpFuncs[i] is the compare
// function applied to the column at index colIdx[i]; buildRowComparer fills
// both slices in lockstep, so they have the same length.
type rowComparer struct {
cmpFuncs []chunk.CompareFunc
colIdx []int
}
// buildRowComparer creates a rowComparer for cols. For each column it records
// the compare function obtained from the column's RetType and the column's
// Index, in the same order as cols.
func buildRowComparer(cols []*expression.Column) rowComparer {
	n := len(cols)
	comparers := make([]chunk.CompareFunc, n)
	indexes := make([]int, n)
	for i, c := range cols {
		comparers[i] = chunk.GetCompareFunc(c.RetType)
		indexes[i] = c.Index
	}
	return rowComparer{cmpFuncs: comparers, colIdx: indexes}
}
// compareRows compares prev and curr on each column listed in rc.colIdx, in
// order, using the matching entry of rc.cmpFuncs. It returns the first
// non-zero comparison result, or 0 when every column compares equal (including
// when there are no columns to compare).
func (rc *rowComparer) compareRows(prev, curr chunk.Row) int {
	for i := range rc.colIdx {
		offset := rc.colIdx[i]
		if cmp := rc.cmpFuncs[i](prev, offset, curr, offset); cmp != 0 {
			return cmp
		}
	}
	return 0
}

View File

@ -103,4 +103,11 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
result = tk.MustQuery("select a, first_value(rand(0)) over(), last_value(rand(0)) over() from t")
result.Check(testkit.Rows("1 0.9451961492941164 0.05434383959970039", "1 0.9451961492941164 0.05434383959970039",
"2 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039"))
result = tk.MustQuery("select a, b, cume_dist() over() from t")
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1"))
result = tk.MustQuery("select a, b, cume_dist() over(order by a) from t")
result.Check(testkit.Rows("1 1 0.5", "1 2 0.5", "2 1 1", "2 2 1"))
result = tk.MustQuery("select a, b, cume_dist() over(order by a, b) from t")
result.Check(testkit.Rows("1 1 0.25", "1 2 0.5", "2 1 0.75", "2 2 1"))
}

View File

@ -98,6 +98,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
a.typeInfer4BitFuncs(ctx)
case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
a.typeInfer4NumberFuncs()
case ast.WindowFuncCumeDist:
a.typeInfer4CumeDist()
default:
panic("unsupported agg function: " + a.Name)
}
@ -193,6 +195,11 @@ func (a *baseFuncDesc) typeInfer4NumberFuncs() {
types.SetBinChsClnFlag(a.RetTp)
}
// typeInfer4CumeDist sets the return type to a double field type whose Flen is
// mysql.MaxRealWidth and whose Decimal is mysql.NotFixedDec.
func (a *baseFuncDesc) typeInfer4CumeDist() {
	retTp := types.NewFieldType(mysql.TypeDouble)
	retTp.Flen = mysql.MaxRealWidth
	retTp.Decimal = mysql.NotFixedDec
	a.RetTp = retTp
}
// GetDefaultValue gets the default value when the function's input is null.
// According to MySQL, default values of the function are listed as follows:
// e.g.