diff --git a/ast/functions.go b/ast/functions.go index 3245780bfb..e49aa17a9b 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -425,7 +425,7 @@ func (n *AggregateFuncExpr) Update() error { return n.updateMaxMin(true) case AggFuncMin: return n.updateMaxMin(false) - case AggFuncSum: + case AggFuncSum, AggFuncAvg: return n.updateSum() case AggFuncGroupConcat: return n.updateGroupConcat() @@ -534,6 +534,7 @@ func (n *AggregateFuncExpr) updateSum() error { if err != nil { return errors.Trace(err) } + ctx.Count++ return nil } diff --git a/optimizer/evaluator/evaluator.go b/optimizer/evaluator/evaluator.go index ea1a8b12e7..0d8821fb32 100644 --- a/optimizer/evaluator/evaluator.go +++ b/optimizer/evaluator/evaluator.go @@ -956,6 +956,8 @@ func parseDayInterval(value interface{}) (int64, error) { func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { name := strings.ToLower(v.F) switch name { + case ast.AggFuncAvg: + e.evalAggAvg(v) case ast.AggFuncCount: e.evalAggCount(v) case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum, ast.AggFuncGroupConcat: @@ -973,3 +975,14 @@ func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) { ctx := v.GetContext() v.SetValue(ctx.Value) } + +func (e *Evaluator) evalAggAvg(v *ast.AggregateFuncExpr) { + ctx := v.GetContext() + switch x := ctx.Value.(type) { + case float64: + ctx.Value = x / float64(ctx.Count) + case mysql.Decimal: + ctx.Value = x.Div(mysql.NewDecimalFromUint(uint64(ctx.Count), 0)) + } + v.SetValue(ctx.Value) +} diff --git a/optimizer/evaluator/evaluator_test.go b/optimizer/evaluator/evaluator_test.go index 573d6ad376..097315ff98 100644 --- a/optimizer/evaluator/evaluator_test.go +++ b/optimizer/evaluator/evaluator_test.go @@ -1054,3 +1054,27 @@ func (s *testEvaluatorSuite) TestColumnNameExpr(c *C) { c.Assert(err, IsNil) c.Assert(result, Equals, 2) } + +func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) { + ctx := mock.NewContext() + avg := &ast.AggregateFuncExpr{ + F: ast.AggFuncAvg, + } + avg.CurrentGroup = "emptyGroup" + result, err := Eval(ctx, avg) + c.Assert(err, IsNil) + // Empty group should return nil. + c.Assert(result, IsNil) + + avg.Args = []ast.ExprNode{ast.NewValueExpr(2)} + avg.Update() + avg.Args = []ast.ExprNode{ast.NewValueExpr(4)} + avg.Update() + + result, err = Eval(ctx, avg) + c.Assert(err, IsNil) + expect, _ := mysql.ConvertToDecimal(3) + v, ok := result.(mysql.Decimal) + c.Assert(ok, IsTrue) + c.Assert(v.Equals(expect), IsTrue) +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index b06748983a..c803b62abf 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -65,18 +65,10 @@ type supportChecker struct { } func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) { - switch ti := in.(type) { + switch x := in.(type) { case *ast.SubqueryExpr: c.unsupported = true - case *ast.AggregateFuncExpr: - fn := strings.ToLower(ti.F) - switch fn { - case ast.AggFuncCount, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum, ast.AggFuncGroupConcat: - default: - c.unsupported = true - } case *ast.Join: - x := in.(*ast.Join) if x.Right != nil { c.unsupported = true } else { @@ -93,7 +85,6 @@ func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) { } } case *ast.SelectStmt: - x := in.(*ast.SelectStmt) if x.Distinct { c.unsupported = true } @@ -123,25 +114,28 @@ func IsSupported(node ast.Node) bool { // Optimizer error codes. const ( - CodeOneColumn terror.ErrCode = 1 - CodeSameColumns = 2 - CodeMultiWildCard = 3 - CodeUnsupported = 4 + CodeOneColumn terror.ErrCode = 1 + CodeSameColumns terror.ErrCode = 2 + CodeMultiWildCard terror.ErrCode = 3 + CodeUnsupported terror.ErrCode = 4 + CodeInvalidGroupFuncUse terror.ErrCode = 5 ) // Optimizer base errors. var ( - ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)") - ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns") - ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once") - ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported") + ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)") + ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns") + ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once") + ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported") + ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(CodeInvalidGroupFuncUse, "Invalid use of group function") ) func init() { mySQLErrCodes := map[terror.ErrCode]uint16{ - CodeOneColumn: mysql.ErrOperandColumns, - CodeSameColumns: mysql.ErrOperandColumns, - CodeMultiWildCard: mysql.ErrParse, + CodeOneColumn: mysql.ErrOperandColumns, + CodeSameColumns: mysql.ErrOperandColumns, + CodeMultiWildCard: mysql.ErrParse, + CodeInvalidGroupFuncUse: mysql.ErrInvalidGroupFuncUse, } terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mySQLErrCodes } diff --git a/optimizer/typeinferer.go b/optimizer/typeinferer.go index 76d0479760..9e6a4bfd52 100644 --- a/optimizer/typeinferer.go +++ b/optimizer/typeinferer.go @@ -135,7 +135,7 @@ func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) { x.SetType(ft) case ast.AggFuncMax, ast.AggFuncMin: x.SetType(x.Args[0].GetType()) - case ast.AggFuncSum: + case ast.AggFuncSum, ast.AggFuncAvg: ft := types.NewFieldType(mysql.TypeNewDecimal) ft.Charset = charset.CharsetBin ft.Collate = charset.CollationBin diff --git a/optimizer/validator.go b/optimizer/validator.go index ea9f208cc3..39ec3e3105 100644 --- a/optimizer/validator.go +++ b/optimizer/validator.go @@ -32,14 +32,26 @@ type validator struct { err error wildCardCount int inPrepare bool + inAggregate bool } func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + switch in.(type) { + case *ast.AggregateFuncExpr: + if v.inAggregate { + // Aggregate function can not contain aggregate function. + v.err = ErrInvalidGroupFuncUse + return in, true + } + v.inAggregate = true + } return in, false } func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { switch x := in.(type) { + case *ast.AggregateFuncExpr: + v.inAggregate = false case *ast.BetweenExpr: v.checkAllOneColumn(x.Expr, x.Left, x.Right) case *ast.BinaryOperationExpr: diff --git a/session.go b/session.go index 98f065babb..c01266fa6f 100644 --- a/session.go +++ b/session.go @@ -554,6 +554,18 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool { name := strs[0] host := strs[1] pwd, err := s.getPassword(name, host) + if err != nil { + if terror.ExecResultIsEmpty.Equal(err) { + log.Errorf("User [%s] not exist %v", name, err) + } else { + log.Errorf("Get User [%s] password from SystemDB error %v", name, err) + } + return false + } + if len(pwd) != 0 && len(pwd) != 40 { + log.Errorf("User [%s] password from SystemDB not like a sha1sum", name) + return false + } hpwd, err := util.DecodePassword(pwd) if err != nil { log.Errorf("Decode password string error %v", err) diff --git a/session_test.go b/session_test.go index 46002e74b9..ef941da48e 100644 --- a/session_test.go +++ b/session_test.go @@ -1326,6 +1326,13 @@ func (s *testSessionSuite) TestSession(c *C) { c.Assert(err, IsNil) } +func (s *testSessionSuite) TestSessionAuth(c *C) { + store := newStore(c, s.dbName) + se := newSession(c, store, s.dbName) + defer se.Close() + c.Assert(se.Auth("Any not exist username with zero password! @anyhost", []byte(""), []byte("")), IsFalse) +} + func (s *testSessionSuite) TestErrorRollback(c *C) { store := newStore(c, s.dbName) s1 := newSession(c, store, s.dbName)