diff --git a/executor/aggfuncs/aggfuncs.go b/executor/aggfuncs/aggfuncs.go index 279c2e2ed7..a65644a6a1 100644 --- a/executor/aggfuncs/aggfuncs.go +++ b/executor/aggfuncs/aggfuncs.go @@ -183,6 +183,10 @@ type baseAggFunc struct { // ordinal stores the ordinal of the columns in the output chunk, which is // used to append the final result of this function. ordinal int + + // frac stores digits of the fractional part of decimals, + // which makes the decimal be the result of type inferring. + frac int } func (*baseAggFunc) MergePartialResult(sctx sessionctx.Context, src, dst PartialResult) (memDelta int64, err error) { diff --git a/executor/aggfuncs/builder.go b/executor/aggfuncs/builder.go index a8a9780289..ea8b6ec6a9 100644 --- a/executor/aggfuncs/builder.go +++ b/executor/aggfuncs/builder.go @@ -17,6 +17,7 @@ import ( "fmt" "strconv" + "github.com/cznic/mathutil" "github.com/pingcap/parser/ast" "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/expression" @@ -247,6 +248,11 @@ func buildSum(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi ordinal: ordinal, }, } + frac := base.args[0].GetType().Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) switch aggFuncDesc.Mode { case aggregation.DedupMode: return nil @@ -275,6 +281,15 @@ func buildAvg(ctx sessionctx.Context, aggFuncDesc *aggregation.AggFuncDesc, ordi args: aggFuncDesc.Args, ordinal: ordinal, } + frac := base.args[0].GetType().Decimal + if len(base.args) == 2 { + frac = base.args[1].GetType().Decimal + } + if frac == -1 { + frac = mysql.MaxDecimalScale + } + base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) + switch aggFuncDesc.Mode { // Build avg functions which consume the original data and remove the // duplicated input of the same group. @@ -319,6 +334,11 @@ func buildFirstRow(aggFuncDesc *aggregation.AggFuncDesc, ordinal int) AggFunc { args: aggFuncDesc.Args, ordinal: ordinal, } + frac := base.args[0].GetType().Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp if fieldType.Tp == mysql.TypeBit { @@ -368,6 +388,11 @@ func buildMaxMin(aggFuncDesc *aggregation.AggFuncDesc, ordinal int, isMax bool) }, isMax: isMax, } + frac := base.args[0].GetType().Decimal + if frac == -1 { + frac = mysql.MaxDecimalScale + } + base.frac = mathutil.Min(frac, mysql.MaxDecimalScale) evalType, fieldType := aggFuncDesc.RetTp.EvalType(), aggFuncDesc.RetTp if fieldType.Tp == mysql.TypeBit { diff --git a/executor/aggfuncs/func_avg.go b/executor/aggfuncs/func_avg.go index 13cf713544..29c138b102 100644 --- a/executor/aggfuncs/func_avg.go +++ b/executor/aggfuncs/func_avg.go @@ -16,8 +16,6 @@ package aggfuncs import ( "unsafe" - "github.com/cznic/mathutil" - "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" @@ -73,15 +71,7 @@ func (e *baseAvgDecimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par if err != nil { return err } - // Make the decimal be the result of type inferring. - frac := e.args[0].GetType().Decimal - if len(e.args) == 2 { - frac = e.args[1].GetType().Decimal - } - if frac == -1 { - frac = mysql.MaxDecimalScale - } - err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven) + err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven) if err != nil { return err } @@ -267,12 +257,7 @@ func (e *avgOriginal4DistinctDecimal) AppendFinalResult2Chunk(sctx sessionctx.Co if err != nil { return err } - // Make the decimal be the result of type inferring. - frac := e.args[0].GetType().Decimal - if frac == -1 { - frac = mysql.MaxDecimalScale - } - err = finalResult.Round(finalResult, mathutil.Min(frac, mysql.MaxDecimalScale), types.ModeHalfEven) + err = finalResult.Round(finalResult, e.frac, types.ModeHalfEven) if err != nil { return err } diff --git a/executor/aggfuncs/func_first_row.go b/executor/aggfuncs/func_first_row.go index 5a6786d95a..5b351f58ca 100644 --- a/executor/aggfuncs/func_first_row.go +++ b/executor/aggfuncs/func_first_row.go @@ -474,6 +474,10 @@ func (e *firstRow4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr P chk.AppendNull(e.ordinal) return nil } + err := p.val.Round(&p.val, e.frac, types.ModeHalfEven) + if err != nil { + return err + } chk.AppendMyDecimal(e.ordinal, &p.val) return nil } diff --git a/executor/aggfuncs/func_max_min.go b/executor/aggfuncs/func_max_min.go index 74f212a90f..e1ac741ae4 100644 --- a/executor/aggfuncs/func_max_min.go +++ b/executor/aggfuncs/func_max_min.go @@ -752,6 +752,10 @@ func (e *maxMin4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Par chk.AppendNull(e.ordinal) return nil } + err := p.val.Round(&p.val, e.frac, types.ModeHalfEven) + if err != nil { + return err + } chk.AppendMyDecimal(e.ordinal, &p.val) return nil } diff --git a/executor/aggfuncs/func_sum.go b/executor/aggfuncs/func_sum.go index 070d97dcec..c4415b482b 100644 --- a/executor/aggfuncs/func_sum.go +++ b/executor/aggfuncs/func_sum.go @@ -166,6 +166,10 @@ func (e *sum4Decimal) AppendFinalResult2Chunk(sctx sessionctx.Context, pr Partia chk.AppendNull(e.ordinal) return nil } + err := p.val.Round(&p.val, e.frac, types.ModeHalfEven) + if err != nil { + return err + } chk.AppendMyDecimal(e.ordinal, &p.val) return nil } diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index f0ce5f11bc..a08e937003 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -1142,3 +1142,29 @@ func (s *testSuiteAgg) TestIssue17216(c *C) { tk.MustExec(`INSERT INTO t1 VALUES (2084,0.02040000000000000000),(35324,0.02190000000000000000),(43760,0.00510000000000000000),(46084,0.01400000000000000000),(46312,0.00560000000000000000),(61632,0.02730000000000000000),(94676,0.00660000000000000000),(102244,0.01810000000000000000),(113144,0.02140000000000000000),(157024,0.02750000000000000000),(157144,0.01750000000000000000),(182076,0.02370000000000000000),(188696,0.02330000000000000000),(833,0.00390000000000000000),(6701,0.00230000000000000000),(8533,0.01690000000000000000),(13801,0.01360000000000000000),(20797,0.00680000000000000000),(36677,0.00550000000000000000),(46305,0.01290000000000000000),(76113,0.00430000000000000000),(76753,0.02400000000000000000),(92393,0.01720000000000000000),(111733,0.02690000000000000000),(152757,0.00250000000000000000),(162393,0.02760000000000000000),(167169,0.00440000000000000000),(168097,0.01360000000000000000),(180309,0.01720000000000000000),(19918,0.02620000000000000000),(58674,0.01820000000000000000),(67454,0.01510000000000000000),(70870,0.02880000000000000000),(89614,0.02530000000000000000),(106742,0.00180000000000000000),(107886,0.01580000000000000000),(147506,0.02230000000000000000),(148366,0.01340000000000000000),(167258,0.01860000000000000000),(194438,0.00500000000000000000),(10307,0.02850000000000000000),(14539,0.02210000000000000000),(27703,0.00050000000000000000),(32495,0.00680000000000000000),(39235,0.01450000000000000000),(52379,0.01640000000000000000),(54551,0.01910000000000000000),(85659,0.02330000000000000000),(104483,0.02670000000000000000),(109911,0.02040000000000000000),(114523,0.02110000000000000000),(119495,0.02120000000000000000),(137603,0.01910000000000000000),(154031,0.02580000000000000000);`) tk.MustQuery("SELECT count(distinct col1) FROM t1").Check(testkit.Rows("48")) } + +func (s *testSuiteAgg) TestIssue19426(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a int primary key, b int)") + tk.MustExec("insert into t values (1, 11), (4, 44), (2, 22), (3, 33)") + tk.MustQuery("select sum(case when a <= 0 or a > 1000 then 0.0 else b end) from t"). + Check(testkit.Rows("110.0")) + tk.MustQuery("select avg(case when a <= 0 or a > 1000 then 0.0 else b end) from t"). + Check(testkit.Rows("27.50000")) + tk.MustQuery("select distinct (case when a <= 0 or a > 1000 then 0.0 else b end) v from t order by v"). + Check(testkit.Rows("11.0", "22.0", "33.0", "44.0")) + tk.MustQuery("select group_concat(case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t"). + Check(testkit.Rows("44.0,33.0,22.0,11.0")) + tk.MustQuery("select group_concat(a, b, case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t"). + Check(testkit.Rows("44444.0,33333.0,22222.0,11111.0")) + tk.MustQuery("select group_concat(distinct case when a <= 0 or a > 1000 then 0.0 else b end order by -a) from t"). + Check(testkit.Rows("44.0,33.0,22.0,11.0")) + tk.MustQuery("select max(case when a <= 0 or a > 1000 then 0.0 else b end) from t"). + Check(testkit.Rows("44.0")) + tk.MustQuery("select min(case when a <= 0 or a > 1000 then 0.0 else b end) from t"). + Check(testkit.Rows("11.0")) + tk.MustQuery("select a, b, sum(case when a < 1000 then b else 0.0 end) over (order by a) from t"). + Check(testkit.Rows("1 11 11.0", "2 22 33.0", "3 33 66.0", "4 44 110.0")) +} diff --git a/expression/aggregation/base_func.go b/expression/aggregation/base_func.go index b9758e044e..f773296553 100644 --- a/expression/aggregation/base_func.go +++ b/expression/aggregation/base_func.go @@ -358,6 +358,18 @@ var noNeedCastAggFuncs = map[string]struct{}{ ast.AggFuncJsonObjectAgg: {}, } +// WrapCastAsDecimalForAggArgs wraps the args of some specific aggregate functions +// with a cast as decimal function. See issue #19426 +func (a *baseFuncDesc) WrapCastAsDecimalForAggArgs(ctx sessionctx.Context) { + if a.Name == ast.AggFuncGroupConcat { + for i := 0; i < len(a.Args)-1; i++ { + if tp := a.Args[i].GetType(); tp.Tp == mysql.TypeNewDecimal { + a.Args[i] = expression.BuildCastFunction(ctx, a.Args[i], tp) + } + } + } +} + // WrapCastForAggArgs wraps the args of an aggregate function with a cast function. func (a *baseFuncDesc) WrapCastForAggArgs(ctx sessionctx.Context) { if len(a.Args) == 0 { diff --git a/planner/core/rule_inject_extra_projection.go b/planner/core/rule_inject_extra_projection.go index a9e2b7ed11..977c73761f 100644 --- a/planner/core/rule_inject_extra_projection.go +++ b/planner/core/rule_inject_extra_projection.go @@ -61,6 +61,7 @@ func (pe *projInjector) inject(plan PhysicalPlan) PhysicalPlan { // since the types of the args are already the expected. func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc) { for i := range aggFuncs { + aggFuncs[i].WrapCastAsDecimalForAggArgs(sctx) if aggFuncs[i].Mode != aggregation.FinalMode && aggFuncs[i].Mode != aggregation.Partial2Mode { aggFuncs[i].WrapCastForAggArgs(sctx) }