From 465a96270f7b846cb118eb3f74373de5c2ea9a2f Mon Sep 17 00:00:00 2001 From: Gogs Date: Wed, 20 Jan 2016 23:01:08 +0800 Subject: [PATCH 1/5] fix session auth bug --- session.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/session.go b/session.go index 98f065babb..6348cd2965 100644 --- a/session.go +++ b/session.go @@ -554,6 +554,19 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool { name := strs[0] host := strs[1] pwd, err := s.getPassword(name, host) + if err != nil { + if terror.ExecResultIsEmpty.Equal(err) { + log.Errorf("User [%s] not exist %v", name, err) + } else { + log.Errorf("Get User [%s] password from SystemDB error %v", name, err) + } + return false + } else { + if len(pwd) != 0 && len(pwd) != 40 { + log.Errorf("User [%s] password from SystemDB not like a sha1sum", name) + return false + } + } hpwd, err := util.DecodePassword(pwd) if err != nil { log.Errorf("Decode password string error %v", err) From 746153cf0d9c7403bb65c87918552e375d114837 Mon Sep 17 00:00:00 2001 From: li Date: Thu, 21 Jan 2016 00:27:48 +0800 Subject: [PATCH 2/5] golint it --- session.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/session.go b/session.go index 6348cd2965..c01266fa6f 100644 --- a/session.go +++ b/session.go @@ -561,11 +561,10 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool { log.Errorf("Get User [%s] password from SystemDB error %v", name, err) } return false - } else { - if len(pwd) != 0 && len(pwd) != 40 { - log.Errorf("User [%s] password from SystemDB not like a sha1sum", name) - return false - } + } + if len(pwd) != 0 && len(pwd) != 40 { + log.Errorf("User [%s] password from SystemDB not like a sha1sum", name) + return false } hpwd, err := util.DecodePassword(pwd) if err != nil { From 8b1bed8e699adc54a6f9101313d6f7f1a8f352cd Mon Sep 17 00:00:00 2001 From: shenli Date: Thu, 21 Jan 2016 17:34:24 +0800 Subject: [PATCH 3/5] *: Support avg in new plan --- ast/functions.go | 3 ++- optimizer/evaluator/evaluator.go | 13 +++++++++++++ optimizer/optimizer.go | 27 +++++++++++++++------------ optimizer/typeinferer.go | 2 +- optimizer/validator.go | 12 ++++++++++++ 5 files changed, 43 insertions(+), 14 deletions(-) diff --git a/ast/functions.go b/ast/functions.go index 64c10a4298..f054365b8d 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -421,7 +421,7 @@ func (n *AggregateFuncExpr) Update() error { return n.updateMaxMin(true) case AggFuncMin: return n.updateMaxMin(false) - case AggFuncSum: + case AggFuncSum, AggFuncAvg: return n.updateSum() } return nil @@ -528,6 +528,7 @@ func (n *AggregateFuncExpr) updateSum() error { if err != nil { return errors.Trace(err) } + ctx.Count++ return nil } diff --git a/optimizer/evaluator/evaluator.go b/optimizer/evaluator/evaluator.go index e78d32c96f..74bbea7068 100644 --- a/optimizer/evaluator/evaluator.go +++ b/optimizer/evaluator/evaluator.go @@ -956,6 +956,8 @@ func parseDayInterval(value interface{}) (int64, error) { func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool { name := strings.ToLower(v.F) switch name { + case ast.AggFuncAvg: + e.evalAggAvg(v) case ast.AggFuncCount: e.evalAggCount(v) case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum: @@ -973,3 +975,14 @@ func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) { ctx := v.GetContext() v.SetValue(ctx.Value) } + +func (e *Evaluator) evalAggAvg(v *ast.AggregateFuncExpr) { + ctx := v.GetContext() + switch x := ctx.Value.(type) { + case float64: + ctx.Value = x / float64(ctx.Count) + case mysql.Decimal: + ctx.Value = x.Div(mysql.NewDecimalFromUint(uint64(ctx.Count), 0)) + } + v.SetValue(ctx.Value) +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index dbb38c2e7c..3c53fea7e7 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -71,7 +71,7 @@ func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) { case *ast.AggregateFuncExpr: fn := strings.ToLower(ti.F) switch fn { - case ast.AggFuncCount, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum: + case ast.AggFuncCount, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum, ast.AggFuncAvg: default: c.unsupported = true } @@ -123,25 +123,28 @@ func IsSupported(node ast.Node) bool { // Optimizer error codes. const ( - CodeOneColumn terror.ErrCode = 1 - CodeSameColumns = 2 - CodeMultiWildCard = 3 - CodeUnsupported = 4 + CodeOneColumn terror.ErrCode = 1 + CodeSameColumns terror.ErrCode = 2 + CodeMultiWildCard terror.ErrCode = 3 + CodeUnsupported terror.ErrCode = 4 + CodeInvalidGroupFuncUse terror.ErrCode = 5 ) // Optimizer base errors. var ( - ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)") - ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns") - ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once") - ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported") + ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)") + ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns") + ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once") + ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported") + ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(CodeInvalidGroupFuncUse, "Invalid use of group function") ) func init() { mySQLErrCodes := map[terror.ErrCode]uint16{ - CodeOneColumn: mysql.ErrOperandColumns, - CodeSameColumns: mysql.ErrOperandColumns, - CodeMultiWildCard: mysql.ErrParse, + CodeOneColumn: mysql.ErrOperandColumns, + CodeSameColumns: mysql.ErrOperandColumns, + CodeMultiWildCard: mysql.ErrParse, + CodeInvalidGroupFuncUse: mysql.ErrInvalidGroupFuncUse, } terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mySQLErrCodes } diff --git a/optimizer/typeinferer.go b/optimizer/typeinferer.go index 9f6c3fa590..7b5ce77ad3 100644 --- a/optimizer/typeinferer.go +++ b/optimizer/typeinferer.go @@ -135,7 +135,7 @@ func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) { x.SetType(ft) case ast.AggFuncMax, ast.AggFuncMin: x.SetType(x.Args[0].GetType()) - case ast.AggFuncSum: + case ast.AggFuncSum, ast.AggFuncAvg: ft := types.NewFieldType(mysql.TypeNewDecimal) ft.Charset = charset.CharsetBin ft.Collate = charset.CollationBin diff --git a/optimizer/validator.go b/optimizer/validator.go index ea9f208cc3..39ec3e3105 100644 --- a/optimizer/validator.go +++ b/optimizer/validator.go @@ -32,14 +32,26 @@ type validator struct { err error wildCardCount int inPrepare bool + inAggregate bool } func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { + switch in.(type) { + case *ast.AggregateFuncExpr: + if v.inAggregate { + // Aggregate function can not contain aggregate function. + v.err = ErrInvalidGroupFuncUse + return in, true + } + v.inAggregate = true + } return in, false } func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) { switch x := in.(type) { + case *ast.AggregateFuncExpr: + v.inAggregate = false case *ast.BetweenExpr: v.checkAllOneColumn(x.Expr, x.Left, x.Right) case *ast.BinaryOperationExpr: From 46a0d6a671fee86572e9aac0fe62ac74b54b5776 Mon Sep 17 00:00:00 2001 From: Yuwen Shen Date: Fri, 22 Jan 2016 03:39:26 +0800 Subject: [PATCH 4/5] add Auth testing case. --- session_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/session_test.go b/session_test.go index 46002e74b9..ef941da48e 100644 --- a/session_test.go +++ b/session_test.go @@ -1326,6 +1326,13 @@ func (s *testSessionSuite) TestSession(c *C) { c.Assert(err, IsNil) } +func (s *testSessionSuite) TestSessionAuth(c *C) { + store := newStore(c, s.dbName) + se := newSession(c, store, s.dbName) + defer se.Close() + c.Assert(se.Auth("Any not exist username with zero password! @anyhost", []byte(""), []byte("")), IsFalse) +} + func (s *testSessionSuite) TestErrorRollback(c *C) { store := newStore(c, s.dbName) s1 := newSession(c, store, s.dbName) From 3e8c167db434f17ecf95caf48d15dd63cffa7220 Mon Sep 17 00:00:00 2001 From: shenli Date: Fri, 22 Jan 2016 11:17:41 +0800 Subject: [PATCH 5/5] *: Add unit test for evalAggAvg --- optimizer/evaluator/evaluator_test.go | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/optimizer/evaluator/evaluator_test.go b/optimizer/evaluator/evaluator_test.go index 573d6ad376..097315ff98 100644 --- a/optimizer/evaluator/evaluator_test.go +++ b/optimizer/evaluator/evaluator_test.go @@ -1054,3 +1054,27 @@ func (s *testEvaluatorSuite) TestColumnNameExpr(c *C) { c.Assert(err, IsNil) c.Assert(result, Equals, 2) } + +func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) { + ctx := mock.NewContext() + avg := &ast.AggregateFuncExpr{ + F: ast.AggFuncAvg, + } + avg.CurrentGroup = "emptyGroup" + result, err := Eval(ctx, avg) + c.Assert(err, IsNil) + // Empty group should return nil. + c.Assert(result, IsNil) + + avg.Args = []ast.ExprNode{ast.NewValueExpr(2)} + avg.Update() + avg.Args = []ast.ExprNode{ast.NewValueExpr(4)} + avg.Update() + + result, err = Eval(ctx, avg) + c.Assert(err, IsNil) + expect, _ := mysql.ConvertToDecimal(3) + v, ok := result.(mysql.Decimal) + c.Assert(ok, IsTrue) + c.Assert(v.Equals(expect), IsTrue) +}