From 686dc63352d2ad044347f17c14e390e4795b0184 Mon Sep 17 00:00:00 2001 From: shenli Date: Tue, 19 Jan 2016 17:57:03 +0800 Subject: [PATCH] *: Support sum in new plan --- ast/functions.go | 28 ++++++++++++++++++++ ast/functions_test.go | 44 ++++++++++++++++++++++++++++++++ optimizer/evaluator/evaluator.go | 7 +++++ optimizer/optimizer.go | 1 + optimizer/resolver.go | 11 +++++++- optimizer/typeinferer.go | 5 ++++ util/types/helper.go | 44 ++++++++++++++++++++++++++++++++ 7 files changed, 139 insertions(+), 1 deletion(-) diff --git a/ast/functions.go b/ast/functions.go index 6e275e9068..a46fd68051 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -411,6 +411,8 @@ func (n *AggregateFuncExpr) Update() error { return n.updateCount() case AggFuncFirstRow: return n.updateFirstRow() + case AggFuncSum: + return n.updateSum() } return nil } @@ -467,6 +469,32 @@ func (n *AggregateFuncExpr) updateFirstRow() error { return nil } +func (n *AggregateFuncExpr) updateSum() error { + ctx := n.GetContext() + a := n.Args[0] + value := a.GetValue() + if value == nil { + return nil + } + if n.Distinct { + d, err := ctx.distinctChecker.Check([]interface{}{value}) + if err != nil { + return errors.Trace(err) + } + if !d { + return nil + } + } + sum := ctx.Value + var err error + sum, err = types.CalculateSum(sum, value) + if err != nil { + return errors.Trace(err) + } + ctx.Value = sum + return nil +} + // AggregateFuncExtractor visits Expr tree. // It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr. type AggregateFuncExtractor struct { diff --git a/ast/functions_test.go b/ast/functions_test.go index 3dbeb342fd..a19c48dfd2 100644 --- a/ast/functions_test.go +++ b/ast/functions_test.go @@ -2,6 +2,7 @@ package ast import ( . "github.com/pingcap/check" + "github.com/pingcap/tidb/mysql" ) var _ = Suite(&testFunctionsSuite{}) @@ -101,3 +102,46 @@ func (ts *testFunctionsSuite) TestAggFuncCount(c *C) { ctx = agg.GetContext() c.Assert(ctx.Count, Equals, int64(2)) } + +func (ts *testFunctionsSuite) TestAggFuncSum(c *C) { + args := make([]ExprNode, 1) + // sum with distinct + agg := &AggregateFuncExpr{ + Args: args, + F: AggFuncSum, + Distinct: true, + } + agg.CurrentGroup = "xx" + expr := NewValueExpr(1) + expr1 := NewValueExpr(nil) + expr2 := NewValueExpr(1) + exprs := []ExprNode{expr, expr1, expr2} + for _, e := range exprs { + args[0] = e + agg.Update() + } + ctx := agg.GetContext() + expect, _ := mysql.ConvertToDecimal(1) + v, ok := ctx.Value.(mysql.Decimal) + c.Assert(ok, IsTrue) + c.Assert(v.Equals(expect), IsTrue) + // sum without distinct + agg = &AggregateFuncExpr{ + Args: args, + F: AggFuncSum, + } + agg.CurrentGroup = "xx" + expr = NewValueExpr(2) + expr1 = NewValueExpr(nil) + expr2 = NewValueExpr(2) + exprs = []ExprNode{expr, expr1, expr2} + for _, e := range exprs { + args[0] = e + agg.Update() + } + ctx = agg.GetContext() + expect, _ = mysql.ConvertToDecimal(4) + v, ok = ctx.Value.(mysql.Decimal) + c.Assert(ok, IsTrue) + c.Assert(v.Equals(expect), IsTrue) +} diff --git a/optimizer/evaluator/evaluator.go b/optimizer/evaluator/evaluator.go index f4555c188f..eaf8f6fa26 100644 --- a/optimizer/evaluator/evaluator.go +++ b/optimizer/evaluator/evaluator.go @@ -960,6 +960,8 @@ func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { e.evalAggCount(v) case ast.AggFuncFirstRow: e.evalAggFirstRow(v) + case ast.AggFuncSum: + e.evalAggSum(v) } return e.err == nil } @@ -973,3 +975,8 @@ func (e *Evaluator) evalAggFirstRow(v *ast.AggregateFuncExpr) { ctx := v.GetContext() v.SetValue(ctx.Value) } + +func (e *Evaluator) evalAggSum(v *ast.AggregateFuncExpr) { + ctx := v.GetContext() + v.SetValue(ctx.Value) +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 3322412669..84e7dee09e 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -90,6 +90,7 @@ func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) { fn := strings.ToLower(ti.F) switch fn { case ast.AggFuncCount: + case ast.AggFuncSum: default: c.unsupported = true } diff --git a/optimizer/resolver.go b/optimizer/resolver.go index 595423f273..def15f976a 100644 --- a/optimizer/resolver.go +++ b/optimizer/resolver.go @@ -306,7 +306,13 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast if nr.resolveColumnInTableSources(cn, ctx.tables) { return true } - return nr.resolveColumnInResultFields(cn, ctx.fieldList) + found := nr.resolveColumnInResultFields(cn, ctx.fieldList) + if !found || nr.Err != nil { + return found + } + if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok { + nr.Err = errors.New("Groupby identifier can not refer to aggregate function.") + } } // Resolve from table first, then from select list. found := nr.resolveColumnInTableSources(cn, ctx.tables) @@ -327,6 +333,9 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast // We should restore its Refer. cn.Refer = r } + if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok { + nr.Err = errors.New("Groupby identifier can not refer to aggregate function.") + } return true } return found diff --git a/optimizer/typeinferer.go b/optimizer/typeinferer.go index e8116a0e2d..e770771aaa 100644 --- a/optimizer/typeinferer.go +++ b/optimizer/typeinferer.go @@ -133,6 +133,11 @@ func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) { ft.Charset = charset.CharsetBin ft.Collate = charset.CollationBin x.SetType(ft) + case "sum": + ft := types.NewFieldType(mysql.TypeNewDecimal) + ft.Charset = charset.CharsetBin + ft.Collate = charset.CollationBin + x.SetType(ft) } } diff --git a/util/types/helper.go b/util/types/helper.go index b9b0817f3b..862680d60f 100644 --- a/util/types/helper.go +++ b/util/types/helper.go @@ -15,6 +15,9 @@ package types import ( "math" + + "github.com/juju/errors" + "github.com/pingcap/tidb/mysql" ) // RoundFloat rounds float val to the nearest integer value with float64 format, like GNU rint function. @@ -67,3 +70,44 @@ func TruncateFloat(f float64, flen int, decimal int) (float64, error) { return f, nil } + +// CalculateSum adds v to sum. +func CalculateSum(sum interface{}, v interface{}) (interface{}, error) { + // for avg and sum calculation + // avg and sum use decimal for integer and decimal type, use float for others + // see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html + var ( + data interface{} + err error + ) + + v = RawData(v) + switch y := v.(type) { + case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64: + data, err = mysql.ConvertToDecimal(v) + case mysql.Decimal: + data = y + case nil: + data = nil + default: + data, err = ToFloat64(v) + } + + if err != nil { + return nil, err + } + if data == nil { + return sum, nil + } + data = RawData(data) + switch x := sum.(type) { + case nil: + return data, nil + case float64: + return x + data.(float64), nil + case mysql.Decimal: + return x.Add(data.(mysql.Decimal)), nil + default: + return nil, errors.Errorf("invalid value %v(%T) for aggregate", x, x) + } +}