*: Support sum in new plan

This commit is contained in:
shenli
2016-01-19 17:57:03 +08:00
committed by Shen Li
parent 150acb9cc2
commit 686dc63352
7 changed files with 139 additions and 1 deletions

View File

@ -411,6 +411,8 @@ func (n *AggregateFuncExpr) Update() error {
return n.updateCount()
case AggFuncFirstRow:
return n.updateFirstRow()
case AggFuncSum:
return n.updateSum()
}
return nil
}
@ -467,6 +469,32 @@ func (n *AggregateFuncExpr) updateFirstRow() error {
return nil
}
func (n *AggregateFuncExpr) updateSum() error {
ctx := n.GetContext()
a := n.Args[0]
value := a.GetValue()
if value == nil {
return nil
}
if n.Distinct {
d, err := ctx.distinctChecker.Check([]interface{}{value})
if err != nil {
return errors.Trace(err)
}
if !d {
return nil
}
}
sum := ctx.Value
var err error
sum, err = types.CalculateSum(sum, value)
if err != nil {
return errors.Trace(err)
}
ctx.Value = sum
return nil
}
// AggregateFuncExtractor visits Expr tree.
// It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr.
type AggregateFuncExtractor struct {

View File

@ -2,6 +2,7 @@ package ast
import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/mysql"
)
var _ = Suite(&testFunctionsSuite{})
@ -101,3 +102,46 @@ func (ts *testFunctionsSuite) TestAggFuncCount(c *C) {
ctx = agg.GetContext()
c.Assert(ctx.Count, Equals, int64(2))
}
func (ts *testFunctionsSuite) TestAggFuncSum(c *C) {
args := make([]ExprNode, 1)
// sum with distinct
agg := &AggregateFuncExpr{
Args: args,
F: AggFuncSum,
Distinct: true,
}
agg.CurrentGroup = "xx"
expr := NewValueExpr(1)
expr1 := NewValueExpr(nil)
expr2 := NewValueExpr(1)
exprs := []ExprNode{expr, expr1, expr2}
for _, e := range exprs {
args[0] = e
agg.Update()
}
ctx := agg.GetContext()
expect, _ := mysql.ConvertToDecimal(1)
v, ok := ctx.Value.(mysql.Decimal)
c.Assert(ok, IsTrue)
c.Assert(v.Equals(expect), IsTrue)
// sum without distinct
agg = &AggregateFuncExpr{
Args: args,
F: AggFuncSum,
}
agg.CurrentGroup = "xx"
expr = NewValueExpr(2)
expr1 = NewValueExpr(nil)
expr2 = NewValueExpr(2)
exprs = []ExprNode{expr, expr1, expr2}
for _, e := range exprs {
args[0] = e
agg.Update()
}
ctx = agg.GetContext()
expect, _ = mysql.ConvertToDecimal(4)
v, ok = ctx.Value.(mysql.Decimal)
c.Assert(ok, IsTrue)
c.Assert(v.Equals(expect), IsTrue)
}

View File

@ -960,6 +960,8 @@ func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
e.evalAggCount(v)
case ast.AggFuncFirstRow:
e.evalAggFirstRow(v)
case ast.AggFuncSum:
e.evalAggSum(v)
}
return e.err == nil
}
@ -973,3 +975,8 @@ func (e *Evaluator) evalAggFirstRow(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
v.SetValue(ctx.Value)
}
func (e *Evaluator) evalAggSum(v *ast.AggregateFuncExpr) {
ctx := v.GetContext()
v.SetValue(ctx.Value)
}

View File

@ -90,6 +90,7 @@ func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) {
fn := strings.ToLower(ti.F)
switch fn {
case ast.AggFuncCount:
case ast.AggFuncSum:
default:
c.unsupported = true
}

View File

@ -306,7 +306,13 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
if nr.resolveColumnInTableSources(cn, ctx.tables) {
return true
}
return nr.resolveColumnInResultFields(cn, ctx.fieldList)
found := nr.resolveColumnInResultFields(cn, ctx.fieldList)
if !found || nr.Err != nil {
return found
}
if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
nr.Err = errors.New("Groupby identifier can not refer to aggregate function.")
}
}
// Resolve from table first, then from select list.
found := nr.resolveColumnInTableSources(cn, ctx.tables)
@ -327,6 +333,9 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
// We should restore its Refer.
cn.Refer = r
}
if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
nr.Err = errors.New("Groupby identifier can not refer to aggregate function.")
}
return true
}
return found

View File

@ -133,6 +133,11 @@ func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) {
ft.Charset = charset.CharsetBin
ft.Collate = charset.CollationBin
x.SetType(ft)
case "sum":
ft := types.NewFieldType(mysql.TypeNewDecimal)
ft.Charset = charset.CharsetBin
ft.Collate = charset.CollationBin
x.SetType(ft)
}
}

View File

@ -15,6 +15,9 @@ package types
import (
"math"
"github.com/juju/errors"
"github.com/pingcap/tidb/mysql"
)
// RoundFloat rounds float val to the nearest integer value with float64 format, like GNU rint function.
@ -67,3 +70,44 @@ func TruncateFloat(f float64, flen int, decimal int) (float64, error) {
return f, nil
}
// CalculateSum adds v to sum.
func CalculateSum(sum interface{}, v interface{}) (interface{}, error) {
// for avg and sum calculation
// avg and sum use decimal for integer and decimal type, use float for others
// see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
var (
data interface{}
err error
)
v = RawData(v)
switch y := v.(type) {
case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
data, err = mysql.ConvertToDecimal(v)
case mysql.Decimal:
data = y
case nil:
data = nil
default:
data, err = ToFloat64(v)
}
if err != nil {
return nil, err
}
if data == nil {
return sum, nil
}
data = RawData(data)
switch x := sum.(type) {
case nil:
return data, nil
case float64:
return x + data.(float64), nil
case mysql.Decimal:
return x.Add(data.(mysql.Decimal)), nil
default:
return nil, errors.Errorf("invalid value %v(%T) for aggregate", x, x)
}
}