executor: support window function cume_dist (#9619)
This commit is contained in:
@ -69,6 +69,8 @@ func BuildWindowFunctions(ctx sessionctx.Context, windowFuncDesc *aggregation.Ag
|
||||
return buildFirstValue(windowFuncDesc, ordinal)
|
||||
case ast.WindowFuncLastValue:
|
||||
return buildLastValue(windowFuncDesc, ordinal)
|
||||
case ast.WindowFuncCumeDist:
|
||||
return buildCumeDist(ordinal, orderByCols)
|
||||
default:
|
||||
return Build(ctx, windowFuncDesc, ordinal)
|
||||
}
|
||||
@ -345,11 +347,7 @@ func buildRank(ordinal int, orderByCols []*expression.Column, isDense bool) AggF
|
||||
base := baseAggFunc{
|
||||
ordinal: ordinal,
|
||||
}
|
||||
r := &rank{baseAggFunc: base, isDense: isDense}
|
||||
for _, col := range orderByCols {
|
||||
r.cmpFuncs = append(r.cmpFuncs, chunk.GetCompareFunc(col.RetType))
|
||||
r.colIdx = append(r.colIdx, col.Index)
|
||||
}
|
||||
r := &rank{baseAggFunc: base, isDense: isDense, rowComparer: buildRowComparer(orderByCols)}
|
||||
return r
|
||||
}
|
||||
|
||||
@ -368,3 +366,11 @@ func buildLastValue(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
|
||||
}
|
||||
return &lastValue{baseAggFunc: base, tp: aggFuncDesc.RetTp}
|
||||
}
|
||||
|
||||
func buildCumeDist(ordinal int, orderByCols []*expression.Column) AggFunc {
|
||||
base := baseAggFunc{
|
||||
ordinal: ordinal,
|
||||
}
|
||||
r := &cumeDist{baseAggFunc: base, rowComparer: buildRowComparer(orderByCols)}
|
||||
return r
|
||||
}
|
||||
|
||||
58
executor/aggfuncs/func_cume_dist.go
Normal file
58
executor/aggfuncs/func_cume_dist.go
Normal file
@ -0,0 +1,58 @@
|
||||
// Copyright 2019 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package aggfuncs
|
||||
|
||||
import (
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/util/chunk"
|
||||
)
|
||||
|
||||
type cumeDist struct {
|
||||
baseAggFunc
|
||||
rowComparer
|
||||
}
|
||||
|
||||
type partialResult4CumeDist struct {
|
||||
curIdx int
|
||||
lastRank int
|
||||
rows []chunk.Row
|
||||
}
|
||||
|
||||
func (r *cumeDist) AllocPartialResult() PartialResult {
|
||||
return PartialResult(&partialResult4Rank{})
|
||||
}
|
||||
|
||||
func (r *cumeDist) ResetPartialResult(pr PartialResult) {
|
||||
p := (*partialResult4Rank)(pr)
|
||||
p.curIdx = 0
|
||||
p.lastRank = 0
|
||||
p.rows = p.rows[:0]
|
||||
}
|
||||
|
||||
func (r *cumeDist) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.Row, pr PartialResult) error {
|
||||
p := (*partialResult4CumeDist)(pr)
|
||||
p.rows = append(p.rows, rowsInGroup...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *cumeDist) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
|
||||
p := (*partialResult4CumeDist)(pr)
|
||||
numRows := len(p.rows)
|
||||
for p.lastRank < numRows && r.compareRows(p.rows[p.curIdx], p.rows[p.lastRank]) == 0 {
|
||||
p.lastRank++
|
||||
}
|
||||
p.curIdx++
|
||||
chk.AppendFloat64(r.ordinal, float64(p.lastRank)/float64(numRows))
|
||||
return nil
|
||||
}
|
||||
@ -14,15 +14,15 @@
|
||||
package aggfuncs
|
||||
|
||||
import (
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/util/chunk"
|
||||
)
|
||||
|
||||
type rank struct {
|
||||
baseAggFunc
|
||||
isDense bool
|
||||
cmpFuncs []chunk.CompareFunc
|
||||
colIdx []int
|
||||
isDense bool
|
||||
rowComparer
|
||||
}
|
||||
|
||||
type partialResult4Rank struct {
|
||||
@ -48,16 +48,6 @@ func (r *rank) UpdatePartialResult(sctx sessionctx.Context, rowsInGroup []chunk.
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rank) compareRows(prev, curr chunk.Row) int {
|
||||
for i, idx := range r.colIdx {
|
||||
res := r.cmpFuncs[i](prev, idx, curr, idx)
|
||||
if res != 0 {
|
||||
return res
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult, chk *chunk.Chunk) error {
|
||||
p := (*partialResult4Rank)(pr)
|
||||
p.curIdx++
|
||||
@ -78,3 +68,29 @@ func (r *rank) AppendFinalResult2Chunk(sctx sessionctx.Context, pr PartialResult
|
||||
chk.AppendInt64(r.ordinal, p.lastRank)
|
||||
return nil
|
||||
}
|
||||
|
||||
type rowComparer struct {
|
||||
cmpFuncs []chunk.CompareFunc
|
||||
colIdx []int
|
||||
}
|
||||
|
||||
func buildRowComparer(cols []*expression.Column) rowComparer {
|
||||
rc := rowComparer{}
|
||||
rc.colIdx = make([]int, 0, len(cols))
|
||||
rc.cmpFuncs = make([]chunk.CompareFunc, 0, len(cols))
|
||||
for _, col := range cols {
|
||||
rc.cmpFuncs = append(rc.cmpFuncs, chunk.GetCompareFunc(col.RetType))
|
||||
rc.colIdx = append(rc.colIdx, col.Index)
|
||||
}
|
||||
return rc
|
||||
}
|
||||
|
||||
func (rc *rowComparer) compareRows(prev, curr chunk.Row) int {
|
||||
for i, idx := range rc.colIdx {
|
||||
res := rc.cmpFuncs[i](prev, idx, curr, idx)
|
||||
if res != 0 {
|
||||
return res
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
@ -103,4 +103,11 @@ func (s *testSuite2) TestWindowFunctions(c *C) {
|
||||
result = tk.MustQuery("select a, first_value(rand(0)) over(), last_value(rand(0)) over() from t")
|
||||
result.Check(testkit.Rows("1 0.9451961492941164 0.05434383959970039", "1 0.9451961492941164 0.05434383959970039",
|
||||
"2 0.9451961492941164 0.05434383959970039", "2 0.9451961492941164 0.05434383959970039"))
|
||||
|
||||
result = tk.MustQuery("select a, b, cume_dist() over() from t")
|
||||
result.Check(testkit.Rows("1 1 1", "1 2 1", "2 1 1", "2 2 1"))
|
||||
result = tk.MustQuery("select a, b, cume_dist() over(order by a) from t")
|
||||
result.Check(testkit.Rows("1 1 0.5", "1 2 0.5", "2 1 1", "2 2 1"))
|
||||
result = tk.MustQuery("select a, b, cume_dist() over(order by a, b) from t")
|
||||
result.Check(testkit.Rows("1 1 0.25", "1 2 0.5", "2 1 0.75", "2 2 1"))
|
||||
}
|
||||
|
||||
@ -98,6 +98,8 @@ func (a *baseFuncDesc) typeInfer(ctx sessionctx.Context) {
|
||||
a.typeInfer4BitFuncs(ctx)
|
||||
case ast.WindowFuncRowNumber, ast.WindowFuncRank, ast.WindowFuncDenseRank:
|
||||
a.typeInfer4NumberFuncs()
|
||||
case ast.WindowFuncCumeDist:
|
||||
a.typeInfer4CumeDist()
|
||||
default:
|
||||
panic("unsupported agg function: " + a.Name)
|
||||
}
|
||||
@ -193,6 +195,11 @@ func (a *baseFuncDesc) typeInfer4NumberFuncs() {
|
||||
types.SetBinChsClnFlag(a.RetTp)
|
||||
}
|
||||
|
||||
func (a *baseFuncDesc) typeInfer4CumeDist() {
|
||||
a.RetTp = types.NewFieldType(mysql.TypeDouble)
|
||||
a.RetTp.Flen, a.RetTp.Decimal = mysql.MaxRealWidth, mysql.NotFixedDec
|
||||
}
|
||||
|
||||
// GetDefaultValue gets the default value when the function's input is null.
|
||||
// According to MySQL, default values of the function are listed as follows:
|
||||
// e.g.
|
||||
|
||||
Reference in New Issue
Block a user