diff --git a/executor/explain_test.go b/executor/explain_test.go index 9b7f9df9e1..eafce87778 100644 --- a/executor/explain_test.go +++ b/executor/explain_test.go @@ -280,10 +280,10 @@ func (s *testSuite) TestExplain(c *C) { { "select count(b.c2) from t1 a, t2 b where a.c1 = b.c2 group by a.c1", []string{ - "TableScan_10", "TableScan_11", "HashAgg_12", "HashLeftJoin_9", "HashAgg_17", + "TableScan_11", "TableScan_12", "HashAgg_13", "HashLeftJoin_10", "Projection_9", }, []string{ - "HashLeftJoin_9", "HashAgg_12", "HashLeftJoin_9", "HashAgg_17", "", + "HashLeftJoin_10", "HashAgg_13", "HashLeftJoin_10", "Projection_9", "", }, []string{`{ "db": "test", @@ -325,7 +325,7 @@ func (s *testSuite) TestExplain(c *C) { "GroupByItems": [ "[b.c2]" ], - "child": "TableScan_11" + "child": "TableScan_12" }`, `{ "eqCond": [ @@ -334,17 +334,14 @@ func (s *testSuite) TestExplain(c *C) { "leftCond": null, "rightCond": null, "otherCond": null, - "leftPlan": "TableScan_10", - "rightPlan": "HashAgg_12" + "leftPlan": "TableScan_11", + "rightPlan": "HashAgg_13" }`, `{ - "AggFuncs": [ - "count(join_agg_0)" + "exprs": [ + "cast(join_agg_0)" ], - "GroupByItems": [ - "a.c1" - ], - "child": "HashLeftJoin_9" + "child": "HashLeftJoin_10" }`, }, }, diff --git a/plan/aggregation_pruning.go b/plan/aggregation_pruning.go deleted file mode 100644 index 0944c54817..0000000000 --- a/plan/aggregation_pruning.go +++ /dev/null @@ -1,116 +0,0 @@ -// Copyright 2017 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// // Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package plan - -import ( - "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/mysql" - "github.com/pingcap/tidb/util/types" -) - -type aggPruner struct { - allocator *idAllocator - ctx context.Context -} - -func (ap *aggPruner) optimize(lp LogicalPlan, ctx context.Context, allocator *idAllocator) (LogicalPlan, error) { - ap.ctx = ctx - ap.allocator = allocator - return ap.eliminateAggregation(lp), nil -} - -// eliminateAggregation will eliminate aggregation grouped by unique key. -// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`. -// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below. -func (ap *aggPruner) eliminateAggregation(p LogicalPlan) LogicalPlan { - retPlan := p - if agg, ok := p.(*Aggregation); ok { - schemaByGroupby := expression.NewSchema(agg.groupByCols...) - coveredByUniqueKey := false - for _, key := range agg.children[0].Schema().Keys { - if schemaByGroupby.ColumnsIndices(key) != nil { - coveredByUniqueKey = true - break - } - } - if coveredByUniqueKey { - // GroupByCols has unique key, so this aggregation can be removed. - proj := convertAggToProj(agg, ap.ctx, ap.allocator) - proj.SetParents(p.Parents()...) 
- for _, child := range p.Children() { - child.SetParents(proj) - } - retPlan = proj - } - } - newChildren := make([]Plan, 0, len(p.Children())) - for _, child := range p.Children() { - newChild := ap.eliminateAggregation(child.(LogicalPlan)) - newChildren = append(newChildren, newChild) - } - retPlan.SetChildren(newChildren...) - return retPlan -} - -func convertAggToProj(agg *Aggregation, ctx context.Context, allocator *idAllocator) *Projection { - proj := &Projection{ - Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)), - baseLogicalPlan: newBaseLogicalPlan(Proj, allocator), - } - proj.self = proj - proj.initIDAndContext(ctx) - for _, fun := range agg.AggFuncs { - expr := rewriteExpr(fun.GetArgs(), fun.GetName(), ctx) - proj.Exprs = append(proj.Exprs, expr) - } - proj.SetSchema(agg.schema.Clone()) - return proj -} - -// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function. -func rewriteExpr(exprs []expression.Expression, funcName string, ctx context.Context) (newExpr expression.Expression) { - switch funcName { - case ast.AggFuncCount: - // If is count(expr), we will change it to if(isnull(expr), 0, 1). - // If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1). - isNullExprs := make([]expression.Expression, 0, len(exprs)) - for _, expr := range exprs { - isNullExpr, _ := expression.NewFunction(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr.Clone()) - isNullExprs = append(isNullExprs, isNullExpr) - } - innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...) - newExpr, _ = expression.NewFunction(ctx, ast.If, types.NewFieldType(mysql.TypeLonglong), innerExpr, expression.Zero, expression.One) - // See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html - // The SUM() and AVG() functions return a DECIMAL value for exact-value arguments (integer or DECIMAL), - // and a DOUBLE value for approximate-value arguments (FLOAT or DOUBLE). - case ast.AggFuncSum, ast.AggFuncAvg: - expr := exprs[0].Clone() - switch expr.GetType().Tp { - // Integer type should be cast to decimal. - case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: - newExpr = expression.NewCastFunc(types.NewFieldType(mysql.TypeNewDecimal), expr, ctx) - // Double and Decimal doesn't need to be cast. - case mysql.TypeDouble, mysql.TypeNewDecimal: - newExpr = expr - // Float should be cast to double. And other non-numeric type should be cast to double too. - default: - newExpr = expression.NewCastFunc(types.NewFieldType(mysql.TypeDouble), expr, ctx) - } - default: - // Default we do nothing about expr. - newExpr = exprs[0].Clone() - } - return -} diff --git a/plan/aggregation_push_down.go b/plan/aggregation_push_down.go index 3deaf7d69f..922222f3a6 100644 --- a/plan/aggregation_push_down.go +++ b/plan/aggregation_push_down.go @@ -23,9 +23,9 @@ import ( "github.com/pingcap/tidb/util/types" ) -type aggPushDownSolver struct { - alloc *idAllocator - ctx context.Context +type aggregationOptimizer struct { + allocator *idAllocator + ctx context.Context } // isDecomposable checks if an aggregate function is decomposable. An aggregation function $F$ is decomposable @@ -34,7 +34,7 @@ type aggPushDownSolver struct { // It's easy to see that max, min, first row is decomposable, no matter whether it's distinct, but sum(distinct) and // count(distinct) is not. // Currently we don't support avg and concat. 
-func (a *aggPushDownSolver) isDecomposable(fun expression.AggregationFunction) bool { +func (a *aggregationOptimizer) isDecomposable(fun expression.AggregationFunction) bool { switch fun.GetName() { case ast.AggFuncAvg, ast.AggFuncGroupConcat: // TODO: Support avg push down. @@ -49,7 +49,7 @@ func (a *aggPushDownSolver) isDecomposable(fun expression.AggregationFunction) b } // getAggFuncChildIdx gets which children it belongs to, 0 stands for left, 1 stands for right, -1 stands for both. -func (a *aggPushDownSolver) getAggFuncChildIdx(aggFunc expression.AggregationFunction, schema *expression.Schema) int { +func (a *aggregationOptimizer) getAggFuncChildIdx(aggFunc expression.AggregationFunction, schema *expression.Schema) int { fromLeft, fromRight := false, false var cols []*expression.Column for _, arg := range aggFunc.GetArgs() { @@ -73,7 +73,7 @@ func (a *aggPushDownSolver) getAggFuncChildIdx(aggFunc expression.AggregationFun // collectAggFuncs collects all aggregate functions and splits them into two parts: "leftAggFuncs" and "rightAggFuncs" whose // arguments are all from left child or right child separately. If some aggregate functions have the arguments that have // columns both from left and right children, the whole aggregation is forbidden to push down. -func (a *aggPushDownSolver) collectAggFuncs(agg *Aggregation, join *Join) (valid bool, leftAggFuncs, rightAggFuncs []expression.AggregationFunction) { +func (a *aggregationOptimizer) collectAggFuncs(agg *Aggregation, join *Join) (valid bool, leftAggFuncs, rightAggFuncs []expression.AggregationFunction) { valid = true leftChild := join.children[0] for _, aggFunc := range agg.AggFuncs { @@ -98,7 +98,7 @@ func (a *aggPushDownSolver) collectAggFuncs(agg *Aggregation, join *Join) (valid // query should be "SELECT SUM(B.agg) FROM A, (SELECT SUM(id) as agg, c1, c2, c3 FROM B GROUP BY id, c1, c2, c3) as B // WHERE A.c1 = B.c1 AND A.c2 != B.c2 GROUP BY B.c3". As you see, all the columns appearing in join-conditions should be // treated as group by columns in join subquery. -func (a *aggPushDownSolver) collectGbyCols(agg *Aggregation, join *Join) (leftGbyCols, rightGbyCols []*expression.Column) { +func (a *aggregationOptimizer) collectGbyCols(agg *Aggregation, join *Join) (leftGbyCols, rightGbyCols []*expression.Column) { leftChild := join.children[0] for _, gbyExpr := range agg.GroupByItems { cols := expression.ExtractColumns(gbyExpr) @@ -136,7 +136,7 @@ func (a *aggPushDownSolver) collectGbyCols(agg *Aggregation, join *Join) (leftGb return } -func (a *aggPushDownSolver) splitAggFuncsAndGbyCols(agg *Aggregation, join *Join) (valid bool, +func (a *aggregationOptimizer) splitAggFuncsAndGbyCols(agg *Aggregation, join *Join) (valid bool, leftAggFuncs, rightAggFuncs []expression.AggregationFunction, leftGbyCols, rightGbyCols []*expression.Column) { valid, leftAggFuncs, rightAggFuncs = a.collectAggFuncs(agg, join) @@ -148,7 +148,7 @@ func (a *aggPushDownSolver) splitAggFuncsAndGbyCols(agg *Aggregation, join *Join } // addGbyCol adds a column to gbyCols. If a group by column has existed, it will not be added repeatedly. 
-func (a *aggPushDownSolver) addGbyCol(gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column { +func (a *aggregationOptimizer) addGbyCol(gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column { for _, c := range cols { duplicate := false for _, gbyCol := range gbyCols { @@ -165,13 +165,13 @@ func (a *aggPushDownSolver) addGbyCol(gbyCols []*expression.Column, cols ...*exp } // checkValidJoin checks if this join should be pushed across. -func (a *aggPushDownSolver) checkValidJoin(join *Join) bool { +func (a *aggregationOptimizer) checkValidJoin(join *Join) bool { return join.JoinType == InnerJoin || join.JoinType == LeftOuterJoin || join.JoinType == RightOuterJoin } // decompose splits an aggregate function to two parts: a final mode function and a partial mode function. Currently // there are no differences between partial mode and complete mode, so we can confuse them. -func (a *aggPushDownSolver) decompose(aggFunc expression.AggregationFunction, schema *expression.Schema, id string) ([]expression.AggregationFunction, *expression.Schema) { +func (a *aggregationOptimizer) decompose(aggFunc expression.AggregationFunction, schema *expression.Schema, id string) ([]expression.AggregationFunction, *expression.Schema) { // Result is a slice because avg should be decomposed to sum and count. Currently we don't process this case. result := []expression.AggregationFunction{aggFunc.Clone()} for _, aggFunc := range result { @@ -187,7 +187,7 @@ func (a *aggPushDownSolver) decompose(aggFunc expression.AggregationFunction, sc return result, schema } -func (a *aggPushDownSolver) allFirstRow(aggFuncs []expression.AggregationFunction) bool { +func (a *aggregationOptimizer) allFirstRow(aggFuncs []expression.AggregationFunction) bool { for _, fun := range aggFuncs { if fun.GetName() != ast.AggFuncFirstRow { return false @@ -199,7 +199,7 @@ func (a *aggPushDownSolver) allFirstRow(aggFuncs []expression.AggregationFunctio // tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't // process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator. // If the pushed aggregation is grouped by unique key, it's no need to push it down. 
-func (a *aggPushDownSolver) tryToPushDownAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column, join *Join, childIdx int) LogicalPlan { +func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column, join *Join, childIdx int) LogicalPlan { child := join.children[childIdx].(LogicalPlan) if a.allFirstRow(aggFuncs) { return child @@ -230,7 +230,7 @@ func (a *aggPushDownSolver) tryToPushDownAgg(aggFuncs []expression.AggregationFu return agg } -func (a *aggPushDownSolver) getDefaultValues(agg *Aggregation) ([]types.Datum, bool) { +func (a *aggregationOptimizer) getDefaultValues(agg *Aggregation) ([]types.Datum, bool) { defaultValues := make([]types.Datum, 0, agg.Schema().Len()) for _, aggFunc := range agg.AggFuncs { value, existsDefaultValue := aggFunc.CalculateDefaultValue(agg.children[0].Schema(), a.ctx) @@ -242,7 +242,7 @@ func (a *aggPushDownSolver) getDefaultValues(agg *Aggregation) ([]types.Datum, b return defaultValues, true } -func (a *aggPushDownSolver) checkAnyCountAndSum(aggFuncs []expression.AggregationFunction) bool { +func (a *aggregationOptimizer) checkAnyCountAndSum(aggFuncs []expression.AggregationFunction) bool { for _, fun := range aggFuncs { if fun.GetName() == ast.AggFuncSum || fun.GetName() == ast.AggFuncCount { return true @@ -251,10 +251,10 @@ func (a *aggPushDownSolver) checkAnyCountAndSum(aggFuncs []expression.Aggregatio return false } -func (a *aggPushDownSolver) makeNewAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column) *Aggregation { +func (a *aggregationOptimizer) makeNewAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column) *Aggregation { agg := &Aggregation{ GroupByItems: expression.Column2Exprs(gbyCols), - baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc), + baseLogicalPlan: newBaseLogicalPlan(Agg, a.allocator), groupByCols: gbyCols, } agg.initIDAndContext(a.ctx) @@ -277,11 +277,11 @@ func (a *aggPushDownSolver) makeNewAgg(aggFuncs []expression.AggregationFunction // pushAggCrossUnion will try to push the agg down to the union. If the new aggregation's group-by columns doesn't contain unique key. // We will return the new aggregation. Otherwise we will transform the aggregation to projection. -func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan { +func (a *aggregationOptimizer) pushAggCrossUnion(agg *Aggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan { newAgg := &Aggregation{ AggFuncs: make([]expression.AggregationFunction, 0, len(agg.AggFuncs)), GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)), - baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc), + baseLogicalPlan: newBaseLogicalPlan(Agg, a.allocator), } newAgg.SetSchema(agg.schema.Clone()) newAgg.initIDAndContext(a.ctx) @@ -305,7 +305,7 @@ func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema *exp // this will cause error during executor phase. 
for _, key := range unionChild.Schema().Keys { if tmpSchema.ColumnsIndices(key) != nil { - proj := convertAggToProj(newAgg, a.ctx, a.alloc) + proj := a.convertAggToProj(newAgg, a.ctx, a.allocator) proj.SetChildren(unionChild) return proj } @@ -315,72 +315,173 @@ func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema *exp return newAgg } -func (a *aggPushDownSolver) optimize(p LogicalPlan, ctx context.Context, alloc *idAllocator) (LogicalPlan, error) { +func (a *aggregationOptimizer) optimize(p LogicalPlan, ctx context.Context, alloc *idAllocator) (LogicalPlan, error) { if !ctx.GetSessionVars().AllowAggPushDown { return p, nil } a.ctx = ctx - a.alloc = alloc + a.allocator = alloc a.aggPushDown(p) return p, nil } // aggPushDown tries to push down aggregate functions to join paths. -func (a *aggPushDownSolver) aggPushDown(p LogicalPlan) { +func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan { if agg, ok := p.(*Aggregation); ok { - child := agg.children[0] - if join, ok1 := child.(*Join); ok1 && a.checkValidJoin(join) { - if valid, leftAggFuncs, rightAggFuncs, leftGbyCols, rightGbyCols := a.splitAggFuncsAndGbyCols(agg, join); valid { - var lChild, rChild LogicalPlan - // If there exist count or sum functions in left join path, we can't push any - // aggregate function into right join path. - rightInvalid := a.checkAnyCountAndSum(leftAggFuncs) - leftInvalid := a.checkAnyCountAndSum(rightAggFuncs) - if rightInvalid { - rChild = join.children[1].(LogicalPlan) - } else { - rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + proj := a.tryToEliminateAggregation(agg) + if proj != nil { + p = proj + } else { + child := agg.children[0] + if join, ok1 := child.(*Join); ok1 && a.checkValidJoin(join) { + if valid, leftAggFuncs, rightAggFuncs, leftGbyCols, rightGbyCols := a.splitAggFuncsAndGbyCols(agg, join); valid { + var lChild, rChild LogicalPlan + // If there exist count or sum functions in left join path, we can't push any + // aggregate function into right join path. + rightInvalid := a.checkAnyCountAndSum(leftAggFuncs) + leftInvalid := a.checkAnyCountAndSum(rightAggFuncs) + if rightInvalid { + rChild = join.children[1].(LogicalPlan) + } else { + rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1) + } + if leftInvalid { + lChild = join.children[0].(LogicalPlan) + } else { + lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + } + join.SetChildren(lChild, rChild) + lChild.SetParents(join) + rChild.SetParents(join) + join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema())) + join.buildKeyInfo() + proj := a.tryToEliminateAggregation(agg) + if proj != nil { + p = proj + } } - if leftInvalid { - lChild = join.children[0].(LogicalPlan) - } else { - lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0) + } else if proj, ok1 := child.(*Projection); ok1 { + // TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet, + // so we must do this optimization. + for i, gbyItem := range agg.GroupByItems { + agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs) } - join.SetChildren(lChild, rChild) - lChild.SetParents(join) - rChild.SetParents(join) - join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema())) - } - } else if proj, ok1 := child.(*Projection); ok1 { - // TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet, - // so we must do this optimization. 
-			for i, gbyItem := range agg.GroupByItems {
-				agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)
-			}
-			agg.collectGroupByColumns()
-			for _, aggFunc := range agg.AggFuncs {
-				newArgs := make([]expression.Expression, 0, len(aggFunc.GetArgs()))
-				for _, arg := range aggFunc.GetArgs() {
-					newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
+				agg.collectGroupByColumns()
+				for _, aggFunc := range agg.AggFuncs {
+					newArgs := make([]expression.Expression, 0, len(aggFunc.GetArgs()))
+					for _, arg := range aggFunc.GetArgs() {
+						newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
+					}
+					aggFunc.SetArgs(newArgs)
 				}
-				aggFunc.SetArgs(newArgs)
+				projChild := proj.children[0]
+				agg.SetChildren(projChild)
+				projChild.SetParents(agg)
+			} else if union, ok1 := child.(*Union); ok1 {
+				pushedAgg := a.makeNewAgg(agg.AggFuncs, agg.groupByCols)
+				newChildren := make([]Plan, 0, len(union.children))
+				for _, child := range union.children {
+					newChild := a.pushAggCrossUnion(pushedAgg, union.schema, child.(LogicalPlan))
+					newChildren = append(newChildren, newChild)
+					newChild.SetParents(union)
+				}
+				union.SetChildren(newChildren...)
+				union.SetSchema(pushedAgg.schema)
 			}
-			projChild := proj.children[0]
-			agg.SetChildren(projChild)
-			projChild.SetParents(agg)
-		} else if union, ok1 := child.(*Union); ok1 {
-			pushedAgg := a.makeNewAgg(agg.AggFuncs, agg.groupByCols)
-			newChildren := make([]Plan, 0, len(union.children))
-			for _, child := range union.children {
-				newChild := a.pushAggCrossUnion(pushedAgg, union.schema, child.(LogicalPlan))
-				newChildren = append(newChildren, newChild)
-				newChild.SetParents(union)
-			}
-			union.SetChildren(newChildren...)
-			union.SetSchema(pushedAgg.schema)
 		}
 	}
+	newChildren := make([]Plan, 0, len(p.Children()))
 	for _, child := range p.Children() {
-		a.aggPushDown(child.(LogicalPlan))
+		newChild := a.aggPushDown(child.(LogicalPlan))
+		newChild.SetParents(p)
+		newChildren = append(newChildren, newChild)
+	}
+	p.SetChildren(newChildren...)
+	return p
+}
+
+// tryToEliminateAggregation will eliminate an aggregation grouped by a unique key.
+// e.g. select min(b) from t group by a. If a is a unique key, then this SQL is equivalent to `select b from t group by a`.
+// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
+// If we can eliminate the aggregation successfully, we return the resulting projection. Otherwise we return a nil pointer.
+func (a *aggregationOptimizer) tryToEliminateAggregation(agg *Aggregation) *Projection {
+	schemaByGroupby := expression.NewSchema(agg.groupByCols...)
+	coveredByUniqueKey := false
+	for _, key := range agg.children[0].Schema().Keys {
+		if schemaByGroupby.ColumnsIndices(key) != nil {
+			coveredByUniqueKey = true
+			break
+		}
+	}
+	if coveredByUniqueKey {
+		// GroupByCols has a unique key, so this aggregation can be removed.
+		proj := a.convertAggToProj(agg, a.ctx, a.allocator)
+		proj.SetChildren(agg.children[0])
+		agg.children[0].SetParents(proj)
+		return proj
+	}
+	return nil
+}
+
+func (a *aggregationOptimizer) convertAggToProj(agg *Aggregation, ctx context.Context, allocator *idAllocator) *Projection {
+	proj := &Projection{
+		Exprs:           make([]expression.Expression, 0, len(agg.AggFuncs)),
+		baseLogicalPlan: newBaseLogicalPlan(Proj, allocator),
+	}
+	proj.self = proj
+	proj.initIDAndContext(ctx)
+	for _, fun := range agg.AggFuncs {
+		expr := a.rewriteExpr(fun)
+		proj.Exprs = append(proj.Exprs, expr)
+	}
+	proj.SetSchema(agg.schema.Clone())
+	return proj
+}
+
+func (a *aggregationOptimizer) rewriteCount(exprs []expression.Expression) expression.Expression {
+	// If it is count(expr), we will change it to if(isnull(expr), 0, 1).
+	// If it is count(distinct x, y, z), we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1).
+	isNullExprs := make([]expression.Expression, 0, len(exprs))
+	for _, expr := range exprs {
+		isNullExpr, _ := expression.NewFunction(a.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr.Clone())
+		isNullExprs = append(isNullExprs, isNullExpr)
+	}
+	innerExpr := expression.ComposeDNFCondition(a.ctx, isNullExprs...)
+	newExpr, _ := expression.NewFunction(a.ctx, ast.If, types.NewFieldType(mysql.TypeLonglong), innerExpr, expression.Zero, expression.One)
+	return newExpr
+}
+
+// See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
+// The SUM() and AVG() functions return a DECIMAL value for exact-value arguments (integer or DECIMAL),
+// and a DOUBLE value for approximate-value arguments (FLOAT or DOUBLE).
+func (a *aggregationOptimizer) rewriteSumOrAvg(exprs []expression.Expression) expression.Expression {
+	// FIXME: Consider the case that avg is in final mode.
+	expr := exprs[0].Clone()
+	switch expr.GetType().Tp {
+	// Integer types should be cast to decimal.
+	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
+		return expression.NewCastFunc(types.NewFieldType(mysql.TypeNewDecimal), expr, a.ctx)
+	// Double and Decimal don't need to be cast.
+	case mysql.TypeDouble, mysql.TypeNewDecimal:
+		return expr
+	// Float should be cast to double, and other non-numeric types should be cast to double too.
+	default:
+		return expression.NewCastFunc(types.NewFieldType(mysql.TypeDouble), expr, a.ctx)
+	}
+}
+
+// rewriteExpr rewrites an aggregate function to an expression that doesn't contain aggregate functions.
+func (a *aggregationOptimizer) rewriteExpr(aggFunc expression.AggregationFunction) expression.Expression {
+	switch aggFunc.GetName() {
+	case ast.AggFuncCount:
+		if aggFunc.GetMode() == expression.FinalMode {
+			return a.rewriteSumOrAvg(aggFunc.GetArgs())
+		}
+		return a.rewriteCount(aggFunc.GetArgs())
+	case ast.AggFuncSum, ast.AggFuncAvg:
+		return a.rewriteSumOrAvg(aggFunc.GetArgs())
+	default:
+		// By default we do nothing about the expr.
+		return aggFunc.GetArgs()[0].Clone()
 	}
 }
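The correctness of rewriteCount above rests on a single observation: once the group-by columns cover a unique key, every group holds exactly one row, so count(e1, ..., en) degenerates into a per-row value that is 0 if any argument is NULL and 1 otherwise — exactly the if(isnull(e1) or ... or isnull(en), 0, 1) tree the function builds. A minimal standalone sketch of that semantics (the helper below is illustrative only, not part of the patch):

package sketch

// countOnSingleRow returns what count(args...) evaluates to over a group that
// is known to contain exactly one row; nil stands in for SQL NULL.
func countOnSingleRow(args []interface{}) int64 {
	for _, a := range args {
		if a == nil {
			return 0 // count never counts a row whose counted argument is NULL
		}
	}
	return 1
}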
diff --git a/plan/build_key_info.go b/plan/build_key_info.go
index ade2fd176e..afcde00e89 100644
--- a/plan/build_key_info.go
+++ b/plan/build_key_info.go
@@ -40,6 +40,18 @@ func (p *Aggregation) buildKeyInfo() {
 		}
 		p.schema.Keys = append(p.schema.Keys, newKey)
 	}
+	// If all group-by items are simple columns and they all survive in the
+	// output schema, they form a unique key of the aggregation's result.
+	if len(p.groupByCols) == len(p.GroupByItems) && len(p.GroupByItems) > 0 {
+		indices := p.schema.ColumnsIndices(p.groupByCols)
+		if indices != nil {
+			newKey := make([]*expression.Column, 0, len(indices))
+			for _, i := range indices {
+				newKey = append(newKey, p.schema.Columns[i])
+			}
+			p.schema.Keys = append(p.schema.Keys, newKey)
+		}
+	}
 	if len(p.GroupByItems) == 0 {
 		p.schema.MaxOneRow = true
 	}
diff --git a/plan/decorrelate.go b/plan/decorrelate.go
index 1125774264..ef40bfaa69 100644
--- a/plan/decorrelate.go
+++ b/plan/decorrelate.go
@@ -153,6 +153,8 @@ func (s *decorrelateSolver) optimize(p LogicalPlan, _ context.Context, _ *idAllo
 			np, _ := s.optimize(p, nil, nil)
 			agg.SetChildren(np)
 			np.SetParents(agg)
+			// Keep groupByCols in sync with GroupByItems; the aggregation optimizer relies on it.
+			agg.collectGroupByColumns()
 			return agg, nil
 		}
 	}
diff --git a/plan/logical_plan_builder.go b/plan/logical_plan_builder.go
index 0647f8c5f1..b7aafdd61f 100644
--- a/plan/logical_plan_builder.go
+++ b/plan/logical_plan_builder.go
@@ -47,8 +47,7 @@ func (p *Aggregation) collectGroupByColumns() {
 
 func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int) {
 	b.optFlag = b.optFlag | flagBuildKeyInfo
-	b.optFlag = b.optFlag | flagEliminateAgg
-	b.optFlag = b.optFlag | flagAggPushDown
+	b.optFlag = b.optFlag | flagAggregationOptimize
 	agg := &Aggregation{
 		AggFuncs:        make([]expression.AggregationFunction, 0, len(aggFuncList)),
 		baseLogicalPlan: newBaseLogicalPlan(Agg, b.allocator)}
@@ -325,8 +324,7 @@ func (b *planBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
 
 func (b *planBuilder) buildDistinct(child LogicalPlan, length int) LogicalPlan {
 	b.optFlag = b.optFlag | flagBuildKeyInfo
-	b.optFlag = b.optFlag | flagEliminateAgg
-	b.optFlag = b.optFlag | flagAggPushDown
+	b.optFlag = b.optFlag | flagAggregationOptimize
 	agg := &Aggregation{
 		baseLogicalPlan: newBaseLogicalPlan(Agg, b.allocator),
 		AggFuncs:        make([]expression.AggregationFunction, 0, child.Schema().Len()),
diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go
index 978e121126..ebaa98b37c 100644
--- a/plan/logical_plan_test.go
+++ b/plan/logical_plan_test.go
@@ -743,7 +743,7 @@ func (s *testPlanSuite) TestAggPushDown(c *C) {
 		},
 		{
 			sql:  "select max(a.b), max(b.b) from t a join t b on a.c = b.c group by a.a",
-			best: "Join{DataScan(a)->DataScan(b)->Aggr(max(b.b),firstrow(b.c))}(a.c,b.c)->Aggr(max(a.b),max(join_agg_0))->Projection",
+			best: "Join{DataScan(a)->DataScan(b)->Aggr(max(b.b),firstrow(b.c))}(a.c,b.c)->Projection->Projection",
 		},
 		{
 			sql: "select max(a.b), max(b.b) from t a join t b on a.a = b.a group by a.c",
@@ -775,7 +775,7 @@ func (s *testPlanSuite) TestAggPushDown(c *C) {
 		p := builder.build(stmt)
 		c.Assert(builder.err, IsNil)
 		lp := p.(LogicalPlan)
-		p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagAggPushDown, lp.(LogicalPlan), builder.ctx, builder.allocator)
+		p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagAggregationOptimize, lp.(LogicalPlan), builder.ctx, builder.allocator)
 		lp.ResolveIndicesAndCorCols()
 		c.Assert(err, IsNil)
 		c.Assert(ToString(lp), Equals, ca.best, Commentf("for %s", ca.sql))
@@ -1290,7 +1290,7 @@ func (s *testPlanSuite) TestUniqueKeyInfo(c *C) {
 		ans map[string][][]string
 	}{
 		{
-			sql: "select a, sum(e) from t group by a",
+			sql: "select a, sum(e) from t group by b",
 			ans: map[string][][]string{
 				"TableScan_1":   {{"test.t.a"}},
 				"Aggregation_2": {{"test.t.a"}},
 			},
 		},
 		{
@@ -1298,23 +1298,23 @@ func (s *testPlanSuite) TestUniqueKeyInfo(c *C) {
-			sql: "select a, sum(f) from t group by a",
+			sql: "select a, b, sum(f) from t group by b",
 			ans: map[string][][]string{
 				"TableScan_1":   {{"test.t.f"}, {"test.t.a"}},
-				"Aggregation_2": {{"test.t.a"}},
-				"Projection_3":  {{"a"}},
+				"Aggregation_2": {{"test.t.a"}, {"test.t.b"}},
+				"Projection_3":  {{"a"}, {"b"}},
 			},
 		},
 		{
 			sql: "select c, d, e, sum(a) from t group by c, d, e",
 			ans: map[string][][]string{
 				"TableScan_1":   {{"test.t.a"}},
-				"Aggregation_2": nil,
-				"Projection_3":  nil,
+				"Aggregation_2": {{"test.t.c", "test.t.d", "test.t.e"}},
+				"Projection_3":  {{"c", "d", "e"}},
 			},
 		},
 		{
-			sql: "select f, g, sum(a) from t group by f, g",
+			sql: "select f, g, sum(a) from t",
 			ans: map[string][][]string{
 				"TableScan_1":   {{"test.t.f"}, {"test.t.g"}, {"test.t.f", "test.t.g"}, {"test.t.a"}},
 				"Aggregation_2": {{"test.t.f"}, {"test.t.g"}, {"test.t.f", "test.t.g"}},
@@ -1394,6 +1394,10 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
 			sql:  "select tt.a, sum(tt.b) from (select a, b from t) tt group by tt.a",
 			best: "DataScan(t)->Projection->Projection->Projection",
 		},
+		{
+			sql:  "select count(1) from (select count(1), a as b from t group by a) tt group by b",
+			best: "DataScan(t)->Projection->Projection->Projection->Projection",
+		},
 	}
 	for _, ca := range cases {
 		comment := Commentf("for %s", ca.sql)
@@ -1410,7 +1414,7 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
 		}
 		p := builder.build(stmt).(LogicalPlan)
 		c.Assert(builder.err, IsNil)
-		p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagEliminateAgg, p.(LogicalPlan), builder.ctx, builder.allocator)
+		p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagAggregationOptimize, p.(LogicalPlan), builder.ctx, builder.allocator)
 		c.Assert(err, IsNil)
 		c.Assert(ToString(p), Equals, ca.best, comment)
 	}
diff --git a/plan/optimizer.go b/plan/optimizer.go
index d00a91dd9d..bbe73bb53d 100644
--- a/plan/optimizer.go
+++ b/plan/optimizer.go
@@ -35,8 +35,7 @@ const (
 	flagBuildKeyInfo
 	flagDecorrelate
 	flagPredicatePushDown
-	flagEliminateAgg
-	flagAggPushDown
+	flagAggregationOptimize
 )
 
 var optRuleList = []logicalOptRule{
@@ -44,8 +43,7 @@
 	&columnPruner{},
 	&buildKeySolver{},
 	&decorrelateSolver{},
 	&ppdSolver{},
-	&aggPruner{},
-	&aggPushDownSolver{},
+	&aggregationOptimizer{},
 }
 
 // logicalOptRule means a logical optimizing rule, which contains decorrelate, ppd, column pruning, etc.
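Note how the flag and the rule list move together: logicalOptimize (not shown in this diff) gates optRuleList[i] on bit i of the flag mask, so merging aggPruner and aggPushDownSolver into one aggregationOptimizer requires replacing the two flags flagEliminateAgg and flagAggPushDown with the single flagAggregationOptimize at the same position. A minimal sketch of that gating convention (illustrative names, not the actual logicalOptimize):

package sketch

type rule interface {
	name() string
}

// enabledRules keeps the invariant the patch relies on: bit i of flag
// switches rules[i] on or off, so the flag constants and the rule list
// must stay aligned in declaration order.
func enabledRules(flag uint64, rules []rule) []rule {
	var out []rule
	for i, r := range rules {
		if flag&(1<<uint(i)) != 0 {
			out = append(out, r)
		}
	}
	return out
}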
diff --git a/plan/physical_plan_test.go b/plan/physical_plan_test.go
index 6a1c0109ba..eea745295e 100644
--- a/plan/physical_plan_test.go
+++ b/plan/physical_plan_test.go
@@ -85,7 +85,7 @@ func (s *testPlanSuite) TestPushDownAggregation(c *C) {
 	_, lp, err = lp.PredicatePushDown(nil)
 	c.Assert(err, IsNil)
 	lp.PruneColumns(lp.Schema().Columns)
-	solver := &aggPushDownSolver{builder.allocator, builder.ctx}
+	solver := &aggregationOptimizer{builder.allocator, builder.ctx}
 	solver.aggPushDown(lp)
 	lp.ResolveIndicesAndCorCols()
 	info, err := lp.convert2PhysicalPlan(&requiredProperty{})
@@ -438,12 +438,12 @@ func (s *testPlanSuite) TestCBO(c *C) {
 			best: "Table(t)->HashAgg->Sort + Limit(1) + Offset(0)->Projection",
 		},
 		{
-			sql:  "select count(*) from t where concat(a,b) = 'abc' group by a",
-			best: "Table(t)->Selection->StreamAgg",
+			sql:  "select count(*) from t where concat(a,b) = 'abc' group by c",
+			best: "Index(t.c_d_e)[[,+inf]]->Selection->StreamAgg",
 		},
 		{
 			sql:  "select count(*) from t where concat(a,b) = 'abc' group by a order by a",
-			best: "Table(t)->Selection->StreamAgg->Projection",
+			best: "Table(t)->Selection->Projection->Projection",
 		},
 		{
 			sql: "select count(distinct e) from t where c = 1 and concat(c,d) = 'abc' group by d",
@@ -508,7 +508,7 @@ func (s *testPlanSuite) TestCBO(c *C) {
 		},
 		{
 			sql:  "select * from (select t.a from t union select t.d from t where t.c = 1 union select t.c from t) k order by a limit 1",
-			best: "UnionAll{Table(t)->HashAgg->Index(t.c_d_e)[[1,1]]->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
+			best: "UnionAll{Table(t)->Projection->Index(t.c_d_e)[[1,1]]->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
 		},
 		{
 			sql: "select * from (select t.a from t union all select t.d from t where t.c = 1 union all select t.c from t) k order by a limit 1",
@@ -516,7 +516,11 @@ func (s *testPlanSuite) TestCBO(c *C) {
 		},
 		{
 			sql:  "select * from (select t.a from t union select t.d from t union select t.c from t) k order by a limit 1",
-			best: "UnionAll{Table(t)->HashAgg->Table(t)->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
+			best: "UnionAll{Table(t)->Projection->Table(t)->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
+		},
+		{
+			sql:  "select t.c from t where 0 = (select count(b) from t t1 where t.a = t1.b)",
+			best: "LeftHashJoin{Table(t)->Table(t)->HashAgg}(test.t.a,t1.b)->Projection->Selection->Projection",
 		},
 	}
 	for _, ca := range cases {
@@ -536,7 +540,7 @@ func (s *testPlanSuite) TestCBO(c *C) {
 		p := builder.build(stmt)
 		c.Assert(builder.err, IsNil)
 		lp := p.(LogicalPlan)
-		lp, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagAggPushDown|flagDecorrelate, lp, builder.ctx, builder.allocator)
+		lp, err = logicalOptimize(flagPredicatePushDown|flagBuildKeyInfo|flagPrunColumns|flagAggregationOptimize|flagDecorrelate, lp, builder.ctx, builder.allocator)
 		lp.ResolveIndicesAndCorCols()
 		info, err := lp.convert2PhysicalPlan(&requiredProperty{})
 		c.Assert(err, IsNil)
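For reference, the heart of tryToEliminateAggregation is a set-containment test: if some unique key of the child's schema is fully covered by the group-by columns, every group is a single row and the aggregation can be replaced by a projection. A self-contained sketch of the test with plain string column names (an assumed simplification; the real code works on *expression.Column values and schema keys):

package sketch

// coveredByUniqueKey reports whether any child key is a subset of the
// group-by column set. If so, e.g. for `select min(b) from t group by pk`,
// each group contains exactly one row and min(b) is just b.
func coveredByUniqueKey(groupByCols map[string]bool, childKeys [][]string) bool {
	for _, key := range childKeys {
		covered := true
		for _, col := range key {
			if !groupByCols[col] {
				covered = false
				break
			}
		}
		if covered {
			return true
		}
	}
	return false
}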
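Finally, the push-down half of the rule hinges on the decomposability property documented at isDecomposable: F(S1 ∪ S2) must be computable from F(S1) and F(S2). A sketch of that contract for sum, whose final combiner is again sum (illustrative only, not the patch's decompose):

package sketch

// partialSum plays the role of the partial-mode aggregate pushed below the join.
func partialSum(rows []int64) int64 {
	var s int64
	for _, v := range rows {
		s += v
	}
	return s
}

// finalSum combines partial results above the join:
// sum(S1 ∪ S2) == finalSum(partialSum(S1), partialSum(S2)).
// count decomposes the same way (its final combiner is sum), while
// sum(distinct) does not, because duplicates spanning S1 and S2 would be
// collapsed on only one side.
func finalSum(partials ...int64) int64 {
	var s int64
	for _, v := range partials {
		s += v
	}
	return s
}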