expression, planner: fix decimal results for aggregate functions (#20017)

This commit is contained in:
dy
2020-11-11 14:46:22 +08:00
committed by GitHub
parent a338e35932
commit a3facd0f71
9 changed files with 82 additions and 17 deletions

View File

@ -183,6 +183,10 @@ type baseAggFunc struct {
// ordinal stores the ordinal of the columns in the output chunk, which is
// used to append the final result of this function.
ordinal int
// frac stores digits of the fractional part of decimals,
// which makes the decimal be the result of type inferring.
frac int
}
func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) {

View File

@ -17,6 +17,7 @@ import (
"fmt"
"strconv"
"github.com/cznic/mathutil"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
@ -247,6 +248,11 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
ordinal: ordinal,
},
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
case aggregation.DedupMode:
return nil
@ -275,6 +281,15 @@ func buildAvg(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi
args: aggFuncDesc.Args,
ordinal: ordinal,
}
frac := base.args[0].GetType().Decimal
if len(base.args) == 2 {
frac = base.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
switch aggFuncDesc.Mode {
// Build avg functions which consume the original data and remove the
// duplicated input of the same group.
@ -319,6 +334,11 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc {
args: aggFuncDesc.Args,
ordinal: ordinal,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {
@ -368,6 +388,11 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool)
},
isMax: isMax,
}
frac := base.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
base.frac = mathutil.Min(frac, mysql.MaxDecimalScale)
evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp
if fieldType.Tp == mysql.TypeBit {

View File

@ -16,8 +16,6 @@ package aggfuncs
import (
"unsafe"
"github.com/cznic/mathutil"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
@ -73,15 +71,7 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
if err != nil {
return err
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if len(e.args) == 2 {
frac = e.args[1].GetType().Decimal
}
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
@ -267,12 +257,7 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co
if err != nil {
return err
}
// Make the decimal be the result of type inferring.
frac := e.args[0].GetType().Decimal
if frac == -1 {
frac = mysql.MaxDecimalScale
}
err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven)
err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven)
if err != nil {
return err
}

View File

@ -474,6 +474,10 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, &p.val)
return nil
}

View File

@ -752,6 +752,10 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, &p.val)
return nil
}

View File

@ -166,6 +166,10 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia
chk.AppendNull(e.ordinal)
return nil
}
err := p.val.Round(&p.val, e.frac, types.ModeHalfEven)
if err != nil {
return err
}
chk.AppendMyDecimal(e.ordinal, &p.val)
return nil
}

View File

@ -1142,3 +1142,29 @@ func (s *testSuiteAgg) TestIssue17216(c *C) {
tk.MustExec(`INSERT INTO t1 VALUES (2084,0.02040000000000000000),(35324,0.02190000000000000000),(43760,0.00510000000000000000),(46084,0.01400000000000000000),(46312,0.00560000000000000000),(61632,0.02730000000000000000),(94676,0.00660000000000000000),(102244,0.01810000000000000000),(113144,0.02140000000000000000),(157024,0.02750000000000000000),(157144,0.01750000000000000000),(182076,0.02370000000000000000),(188696,0.02330000000000000000),(833,0.00390000000000000000),(6701,0.00230000000000000000),(8533,0.01690000000000000000),(13801,0.01360000000000000000),(20797,0.00680000000000000000),(36677,0.00550000000000000000),(46305,0.01290000000000000000),(76113,0.00430000000000000000),(76753,0.02400000000000000000),(92393,0.01720000000000000000),(111733,0.02690000000000000000),(152757,0.00250000000000000000),(162393,0.02760000000000000000),(167169,0.00440000000000000000),(168097,0.01360000000000000000),(180309,0.01720000000000000000),(19918,0.02620000000000000000),(58674,0.01820000000000000000),(67454,0.01510000000000000000),(70870,0.02880000000000000000),(89614,0.02530000000000000000),(106742,0.00180000000000000000),(107886,0.01580000000000000000),(147506,0.02230000000000000000),(148366,0.01340000000000000000),(167258,0.01860000000000000000),(194438,0.00500000000000000000),(10307,0.02850000000000000000),(14539,0.02210000000000000000),(27703,0.00050000000000000000),(32495,0.00680000000000000000),(39235,0.01450000000000000000),(52379,0.01640000000000000000),(54551,0.01910000000000000000),(85659,0.02330000000000000000),(104483,0.02670000000000000000),(109911,0.02040000000000000000),(114523,0.02110000000000000000),(119495,0.02120000000000000000),(137603,0.01910000000000000000),(154031,0.02580000000000000000);`)
tk.MustQuery("SELECT count(distinct col1) FROM t1").Check(testkit.Rows("48"))
}
func (s *testSuiteAgg) TestIssue19426(c *C) {
tk := testkit.NewTestKitWithInit(c, s.store)
tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a int primary key, b int)")
tk.MustExec("insert into t values (1, 11), (4, 44), (2, 22), (3, 33)")
tk.MustQuery("select sum(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("110.0"))
tk.MustQuery("select avg(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("27.50000"))
tk.MustQuery("select distinct (case when a <= 0 or a > 1000 then 0.0 else b end) v from t order by v").
Check(testkit.Rows("11.0", "22.0", "33.0", "44.0"))
tk.MustQuery("select group_concat(case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t").
Check(testkit.Rows("44.0,33.0,22.0,11.0"))
tk.MustQuery("select group_concat(a, b, case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t").
Check(testkit.Rows("44444.0,33333.0,22222.0,11111.0"))
tk.MustQuery("select group_concat(distinct case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t").
Check(testkit.Rows("44.0,33.0,22.0,11.0"))
tk.MustQuery("select max(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("44.0"))
tk.MustQuery("select min(case when a <= 0 or a > 1000 then 0.0 else b end) from t").
Check(testkit.Rows("11.0"))
tk.MustQuery("select a, b, sum(case when a < 1000 then b else 0.0 end) over (order by a) from t").
Check(testkit.Rows("1 11 11.0", "2 22 33.0", "3 33 66.0", "4 44 110.0"))
}

View File

@ -358,6 +358,18 @@ var noNeedCastAggFuncs = map[string]struct{}{
ast.AggFuncJsonObjectAgg: {},
}
// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions
// with a cast as decimal function. See issue #19426
func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) {
if a.Name == ast.AggFuncGroupConcat {
for i := 0; i < len(a.Args)-1; i++ {
if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal {
a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp)
}
}
}
}
// WrapCastForAggArgs wraps the args of an aggregate function with a cast function.
func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) {
if len(a.Args) == 0 {

View File

@ -61,6 +61,7 @@ func (pe *projInjector) inject(plan PhysicalPlan) PhysicalPlan {
// since the types of the args are already the expected.
func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) {
for i := range aggFuncs {
aggFuncs[i].WrapCastAsDecimalForAggArgs(sctx)
if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode {
aggFuncs[i].WrapCastForAggArgs(sctx)
}