*: Support sum in new plan
This commit is contained in:
@ -411,6 +411,8 @@ func (n *AggregateFuncExpr) Update() error {
|
||||
return n.updateCount()
|
||||
case AggFuncFirstRow:
|
||||
return n.updateFirstRow()
|
||||
case AggFuncSum:
|
||||
return n.updateSum()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -467,6 +469,32 @@ func (n *AggregateFuncExpr) updateFirstRow() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *AggregateFuncExpr) updateSum() error {
|
||||
ctx := n.GetContext()
|
||||
a := n.Args[0]
|
||||
value := a.GetValue()
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
if n.Distinct {
|
||||
d, err := ctx.distinctChecker.Check([]interface{}{value})
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
if !d {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
sum := ctx.Value
|
||||
var err error
|
||||
sum, err = types.CalculateSum(sum, value)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
ctx.Value = sum
|
||||
return nil
|
||||
}
|
||||
|
||||
// AggregateFuncExtractor visits Expr tree.
|
||||
// It converts ColunmNameExpr to AggregateFuncExpr and collects AggregateFuncExpr.
|
||||
type AggregateFuncExtractor struct {
|
||||
|
||||
@ -2,6 +2,7 @@ package ast
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
var _ = Suite(&testFunctionsSuite{})
|
||||
@ -101,3 +102,46 @@ func (ts *testFunctionsSuite) TestAggFuncCount(c *C) {
|
||||
ctx = agg.GetContext()
|
||||
c.Assert(ctx.Count, Equals, int64(2))
|
||||
}
|
||||
|
||||
func (ts *testFunctionsSuite) TestAggFuncSum(c *C) {
|
||||
args := make([]ExprNode, 1)
|
||||
// sum with distinct
|
||||
agg := &AggregateFuncExpr{
|
||||
Args: args,
|
||||
F: AggFuncSum,
|
||||
Distinct: true,
|
||||
}
|
||||
agg.CurrentGroup = "xx"
|
||||
expr := NewValueExpr(1)
|
||||
expr1 := NewValueExpr(nil)
|
||||
expr2 := NewValueExpr(1)
|
||||
exprs := []ExprNode{expr, expr1, expr2}
|
||||
for _, e := range exprs {
|
||||
args[0] = e
|
||||
agg.Update()
|
||||
}
|
||||
ctx := agg.GetContext()
|
||||
expect, _ := mysql.ConvertToDecimal(1)
|
||||
v, ok := ctx.Value.(mysql.Decimal)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(v.Equals(expect), IsTrue)
|
||||
// sum without distinct
|
||||
agg = &AggregateFuncExpr{
|
||||
Args: args,
|
||||
F: AggFuncSum,
|
||||
}
|
||||
agg.CurrentGroup = "xx"
|
||||
expr = NewValueExpr(2)
|
||||
expr1 = NewValueExpr(nil)
|
||||
expr2 = NewValueExpr(2)
|
||||
exprs = []ExprNode{expr, expr1, expr2}
|
||||
for _, e := range exprs {
|
||||
args[0] = e
|
||||
agg.Update()
|
||||
}
|
||||
ctx = agg.GetContext()
|
||||
expect, _ = mysql.ConvertToDecimal(4)
|
||||
v, ok = ctx.Value.(mysql.Decimal)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(v.Equals(expect), IsTrue)
|
||||
}
|
||||
|
||||
@ -960,6 +960,8 @@ func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
|
||||
e.evalAggCount(v)
|
||||
case ast.AggFuncFirstRow:
|
||||
e.evalAggFirstRow(v)
|
||||
case ast.AggFuncSum:
|
||||
e.evalAggSum(v)
|
||||
}
|
||||
return e.err == nil
|
||||
}
|
||||
@ -973,3 +975,8 @@ func (e *Evaluator) evalAggFirstRow(v *ast.AggregateFuncExpr) {
|
||||
ctx := v.GetContext()
|
||||
v.SetValue(ctx.Value)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evalAggSum(v *ast.AggregateFuncExpr) {
|
||||
ctx := v.GetContext()
|
||||
v.SetValue(ctx.Value)
|
||||
}
|
||||
|
||||
@ -90,6 +90,7 @@ func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) {
|
||||
fn := strings.ToLower(ti.F)
|
||||
switch fn {
|
||||
case ast.AggFuncCount:
|
||||
case ast.AggFuncSum:
|
||||
default:
|
||||
c.unsupported = true
|
||||
}
|
||||
|
||||
@ -306,7 +306,13 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
|
||||
if nr.resolveColumnInTableSources(cn, ctx.tables) {
|
||||
return true
|
||||
}
|
||||
return nr.resolveColumnInResultFields(cn, ctx.fieldList)
|
||||
found := nr.resolveColumnInResultFields(cn, ctx.fieldList)
|
||||
if !found || nr.Err != nil {
|
||||
return found
|
||||
}
|
||||
if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
|
||||
nr.Err = errors.New("Groupby identifier can not refer to aggregate function.")
|
||||
}
|
||||
}
|
||||
// Resolve from table first, then from select list.
|
||||
found := nr.resolveColumnInTableSources(cn, ctx.tables)
|
||||
@ -327,6 +333,9 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
|
||||
// We should restore its Refer.
|
||||
cn.Refer = r
|
||||
}
|
||||
if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
|
||||
nr.Err = errors.New("Groupby identifier can not refer to aggregate function.")
|
||||
}
|
||||
return true
|
||||
}
|
||||
return found
|
||||
|
||||
@ -133,6 +133,11 @@ func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) {
|
||||
ft.Charset = charset.CharsetBin
|
||||
ft.Collate = charset.CollationBin
|
||||
x.SetType(ft)
|
||||
case "sum":
|
||||
ft := types.NewFieldType(mysql.TypeNewDecimal)
|
||||
ft.Charset = charset.CharsetBin
|
||||
ft.Collate = charset.CollationBin
|
||||
x.SetType(ft)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -15,6 +15,9 @@ package types
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
// RoundFloat rounds float val to the nearest integer value with float64 format, like GNU rint function.
|
||||
@ -67,3 +70,44 @@ func TruncateFloat(f float64, flen int, decimal int) (float64, error) {
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// CalculateSum adds v to sum.
|
||||
func CalculateSum(sum interface{}, v interface{}) (interface{}, error) {
|
||||
// for avg and sum calculation
|
||||
// avg and sum use decimal for integer and decimal type, use float for others
|
||||
// see https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
|
||||
var (
|
||||
data interface{}
|
||||
err error
|
||||
)
|
||||
|
||||
v = RawData(v)
|
||||
switch y := v.(type) {
|
||||
case int, uint, int8, uint8, int16, uint16, int32, uint32, int64, uint64:
|
||||
data, err = mysql.ConvertToDecimal(v)
|
||||
case mysql.Decimal:
|
||||
data = y
|
||||
case nil:
|
||||
data = nil
|
||||
default:
|
||||
data, err = ToFloat64(v)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if data == nil {
|
||||
return sum, nil
|
||||
}
|
||||
data = RawData(data)
|
||||
switch x := sum.(type) {
|
||||
case nil:
|
||||
return data, nil
|
||||
case float64:
|
||||
return x + data.(float64), nil
|
||||
case mysql.Decimal:
|
||||
return x.Add(data.(mysql.Decimal)), nil
|
||||
default:
|
||||
return nil, errors.Errorf("invalid value %v(%T) for aggregate", x, x)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user