Merge remote-tracking branch 'origin/master' into shenli/group-concat
This commit is contained in:
@ -425,7 +425,7 @@ func (n *AggregateFuncExpr) Update() error {
|
||||
return n.updateMaxMin(true)
|
||||
case AggFuncMin:
|
||||
return n.updateMaxMin(false)
|
||||
case AggFuncSum:
|
||||
case AggFuncSum, AggFuncAvg:
|
||||
return n.updateSum()
|
||||
case AggFuncGroupConcat:
|
||||
return n.updateGroupConcat()
|
||||
@ -534,6 +534,7 @@ func (n *AggregateFuncExpr) updateSum() error {
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
ctx.Count++
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -956,6 +956,8 @@ func parseDayInterval(value interface{}) (int64, error) {
|
||||
func (e *Evaluator) aggregateFunc(v *ast.AggregateFuncExpr) bool {
|
||||
name := strings.ToLower(v.F)
|
||||
switch name {
|
||||
case ast.AggFuncAvg:
|
||||
e.evalAggAvg(v)
|
||||
case ast.AggFuncCount:
|
||||
e.evalAggCount(v)
|
||||
case ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum, ast.AggFuncGroupConcat:
|
||||
@ -973,3 +975,14 @@ func (e *Evaluator) evalAggSetValue(v *ast.AggregateFuncExpr) {
|
||||
ctx := v.GetContext()
|
||||
v.SetValue(ctx.Value)
|
||||
}
|
||||
|
||||
func (e *Evaluator) evalAggAvg(v *ast.AggregateFuncExpr) {
|
||||
ctx := v.GetContext()
|
||||
switch x := ctx.Value.(type) {
|
||||
case float64:
|
||||
ctx.Value = x / float64(ctx.Count)
|
||||
case mysql.Decimal:
|
||||
ctx.Value = x.Div(mysql.NewDecimalFromUint(uint64(ctx.Count), 0))
|
||||
}
|
||||
v.SetValue(ctx.Value)
|
||||
}
|
||||
|
||||
@ -1054,3 +1054,27 @@ func (s *testEvaluatorSuite) TestColumnNameExpr(c *C) {
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(result, Equals, 2)
|
||||
}
|
||||
|
||||
func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) {
|
||||
ctx := mock.NewContext()
|
||||
avg := &ast.AggregateFuncExpr{
|
||||
F: ast.AggFuncAvg,
|
||||
}
|
||||
avg.CurrentGroup = "emptyGroup"
|
||||
result, err := Eval(ctx, avg)
|
||||
c.Assert(err, IsNil)
|
||||
// Empty group should return nil.
|
||||
c.Assert(result, IsNil)
|
||||
|
||||
avg.Args = []ast.ExprNode{ast.NewValueExpr(2)}
|
||||
avg.Update()
|
||||
avg.Args = []ast.ExprNode{ast.NewValueExpr(4)}
|
||||
avg.Update()
|
||||
|
||||
result, err = Eval(ctx, avg)
|
||||
c.Assert(err, IsNil)
|
||||
expect, _ := mysql.ConvertToDecimal(3)
|
||||
v, ok := result.(mysql.Decimal)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(v.Equals(expect), IsTrue)
|
||||
}
|
||||
|
||||
@ -65,18 +65,10 @@ type supportChecker struct {
|
||||
}
|
||||
|
||||
func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) {
|
||||
switch ti := in.(type) {
|
||||
switch x := in.(type) {
|
||||
case *ast.SubqueryExpr:
|
||||
c.unsupported = true
|
||||
case *ast.AggregateFuncExpr:
|
||||
fn := strings.ToLower(ti.F)
|
||||
switch fn {
|
||||
case ast.AggFuncCount, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncSum, ast.AggFuncGroupConcat:
|
||||
default:
|
||||
c.unsupported = true
|
||||
}
|
||||
case *ast.Join:
|
||||
x := in.(*ast.Join)
|
||||
if x.Right != nil {
|
||||
c.unsupported = true
|
||||
} else {
|
||||
@ -93,7 +85,6 @@ func (c *supportChecker) Enter(in ast.Node) (ast.Node, bool) {
|
||||
}
|
||||
}
|
||||
case *ast.SelectStmt:
|
||||
x := in.(*ast.SelectStmt)
|
||||
if x.Distinct {
|
||||
c.unsupported = true
|
||||
}
|
||||
@ -123,25 +114,28 @@ func IsSupported(node ast.Node) bool {
|
||||
|
||||
// Optimizer error codes.
|
||||
const (
|
||||
CodeOneColumn terror.ErrCode = 1
|
||||
CodeSameColumns = 2
|
||||
CodeMultiWildCard = 3
|
||||
CodeUnsupported = 4
|
||||
CodeOneColumn terror.ErrCode = 1
|
||||
CodeSameColumns terror.ErrCode = 2
|
||||
CodeMultiWildCard terror.ErrCode = 3
|
||||
CodeUnsupported terror.ErrCode = 4
|
||||
CodeInvalidGroupFuncUse terror.ErrCode = 5
|
||||
)
|
||||
|
||||
// Optimizer base errors.
|
||||
var (
|
||||
ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)")
|
||||
ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns")
|
||||
ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once")
|
||||
ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported")
|
||||
ErrOneColumn = terror.ClassOptimizer.New(CodeOneColumn, "Operand should contain 1 column(s)")
|
||||
ErrSameColumns = terror.ClassOptimizer.New(CodeSameColumns, "Operands should contain same columns")
|
||||
ErrMultiWildCard = terror.ClassOptimizer.New(CodeMultiWildCard, "wildcard field exist more than once")
|
||||
ErrUnSupported = terror.ClassOptimizer.New(CodeUnsupported, "unsupported")
|
||||
ErrInvalidGroupFuncUse = terror.ClassOptimizer.New(CodeInvalidGroupFuncUse, "Invalid use of group function")
|
||||
)
|
||||
|
||||
func init() {
|
||||
mySQLErrCodes := map[terror.ErrCode]uint16{
|
||||
CodeOneColumn: mysql.ErrOperandColumns,
|
||||
CodeSameColumns: mysql.ErrOperandColumns,
|
||||
CodeMultiWildCard: mysql.ErrParse,
|
||||
CodeOneColumn: mysql.ErrOperandColumns,
|
||||
CodeSameColumns: mysql.ErrOperandColumns,
|
||||
CodeMultiWildCard: mysql.ErrParse,
|
||||
CodeInvalidGroupFuncUse: mysql.ErrInvalidGroupFuncUse,
|
||||
}
|
||||
terror.ErrClassToMySQLCodes[terror.ClassOptimizer] = mySQLErrCodes
|
||||
}
|
||||
|
||||
@ -135,7 +135,7 @@ func (v *typeInferrer) aggregateFunc(x *ast.AggregateFuncExpr) {
|
||||
x.SetType(ft)
|
||||
case ast.AggFuncMax, ast.AggFuncMin:
|
||||
x.SetType(x.Args[0].GetType())
|
||||
case ast.AggFuncSum:
|
||||
case ast.AggFuncSum, ast.AggFuncAvg:
|
||||
ft := types.NewFieldType(mysql.TypeNewDecimal)
|
||||
ft.Charset = charset.CharsetBin
|
||||
ft.Collate = charset.CollationBin
|
||||
|
||||
@ -32,14 +32,26 @@ type validator struct {
|
||||
err error
|
||||
wildCardCount int
|
||||
inPrepare bool
|
||||
inAggregate bool
|
||||
}
|
||||
|
||||
func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) {
|
||||
switch in.(type) {
|
||||
case *ast.AggregateFuncExpr:
|
||||
if v.inAggregate {
|
||||
// Aggregate function can not contain aggregate function.
|
||||
v.err = ErrInvalidGroupFuncUse
|
||||
return in, true
|
||||
}
|
||||
v.inAggregate = true
|
||||
}
|
||||
return in, false
|
||||
}
|
||||
|
||||
func (v *validator) Leave(in ast.Node) (out ast.Node, ok bool) {
|
||||
switch x := in.(type) {
|
||||
case *ast.AggregateFuncExpr:
|
||||
v.inAggregate = false
|
||||
case *ast.BetweenExpr:
|
||||
v.checkAllOneColumn(x.Expr, x.Left, x.Right)
|
||||
case *ast.BinaryOperationExpr:
|
||||
|
||||
12
session.go
12
session.go
@ -554,6 +554,18 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool {
|
||||
name := strs[0]
|
||||
host := strs[1]
|
||||
pwd, err := s.getPassword(name, host)
|
||||
if err != nil {
|
||||
if terror.ExecResultIsEmpty.Equal(err) {
|
||||
log.Errorf("User [%s] not exist %v", name, err)
|
||||
} else {
|
||||
log.Errorf("Get User [%s] password from SystemDB error %v", name, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
if len(pwd) != 0 && len(pwd) != 40 {
|
||||
log.Errorf("User [%s] password from SystemDB not like a sha1sum", name)
|
||||
return false
|
||||
}
|
||||
hpwd, err := util.DecodePassword(pwd)
|
||||
if err != nil {
|
||||
log.Errorf("Decode password string error %v", err)
|
||||
|
||||
@ -1326,6 +1326,13 @@ func (s *testSessionSuite) TestSession(c *C) {
|
||||
c.Assert(err, IsNil)
|
||||
}
|
||||
|
||||
func (s *testSessionSuite) TestSessionAuth(c *C) {
|
||||
store := newStore(c, s.dbName)
|
||||
se := newSession(c, store, s.dbName)
|
||||
defer se.Close()
|
||||
c.Assert(se.Auth("Any not exist username with zero password! @anyhost", []byte(""), []byte("")), IsFalse)
|
||||
}
|
||||
|
||||
func (s *testSessionSuite) TestErrorRollback(c *C) {
|
||||
store := newStore(c, s.dbName)
|
||||
s1 := newSession(c, store, s.dbName)
|
||||
|
||||
Reference in New Issue
Block a user