plan: combine agg_prune and agg_pushdown. (#2820)
This commit is contained in:
@ -280,10 +280,10 @@ func (s *testSuite) TestExplain(c *C) {
|
||||
{
|
||||
"select count(b.c2) from t1 a, t2 b where a.c1 = b.c2 group by a.c1",
|
||||
[]string{
|
||||
"TableScan_10", "TableScan_11", "HashAgg_12", "HashLeftJoin_9", "HashAgg_17",
|
||||
"TableScan_11", "TableScan_12", "HashAgg_13", "HashLeftJoin_10", "Projection_9",
|
||||
},
|
||||
[]string{
|
||||
"HashLeftJoin_9", "HashAgg_12", "HashLeftJoin_9", "HashAgg_17", "",
|
||||
"HashLeftJoin_10", "HashAgg_13", "HashLeftJoin_10", "Projection_9", "",
|
||||
},
|
||||
[]string{`{
|
||||
"db": "test",
|
||||
@ -325,7 +325,7 @@ func (s *testSuite) TestExplain(c *C) {
|
||||
"GroupByItems": [
|
||||
"[b.c2]"
|
||||
],
|
||||
"child": "TableScan_11"
|
||||
"child": "TableScan_12"
|
||||
}`,
|
||||
`{
|
||||
"eqCond": [
|
||||
@ -334,17 +334,14 @@ func (s *testSuite) TestExplain(c *C) {
|
||||
"leftCond": null,
|
||||
"rightCond": null,
|
||||
"otherCond": null,
|
||||
"leftPlan": "TableScan_10",
|
||||
"rightPlan": "HashAgg_12"
|
||||
"leftPlan": "TableScan_11",
|
||||
"rightPlan": "HashAgg_13"
|
||||
}`,
|
||||
`{
|
||||
"AggFuncs": [
|
||||
"count(join_agg_0)"
|
||||
"exprs": [
|
||||
"cast(join_agg_0)"
|
||||
],
|
||||
"GroupByItems": [
|
||||
"a.c1"
|
||||
],
|
||||
"child": "HashLeftJoin_9"
|
||||
"child": "HashLeftJoin_10"
|
||||
}`,
|
||||
},
|
||||
},
|
||||
|
||||
@ -1,116 +0,0 @@
|
||||
// Copyright 2017 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
// // Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package plan
|
||||
|
||||
import (
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
type aggPruner struct {
|
||||
allocator *idAllocator
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (ap *aggPruner) optimize(lp LogicalPlan, ctx context.Context, allocator *idAllocator) (LogicalPlan, error) {
|
||||
ap.ctx = ctx
|
||||
ap.allocator = allocator
|
||||
return ap.eliminateAggregation(lp), nil
|
||||
}
|
||||
|
||||
// eliminateAggregation will eliminate aggregation grouped by unique key.
|
||||
// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`.
|
||||
// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
|
||||
func (ap *aggPruner) eliminateAggregation(p LogicalPlan) LogicalPlan {
|
||||
retPlan := p
|
||||
if agg, ok := p.(*Aggregation); ok {
|
||||
schemaByGroupby := expression.NewSchema(agg.groupByCols...)
|
||||
coveredByUniqueKey := false
|
||||
for _, key := range agg.children[0].Schema().Keys {
|
||||
if schemaByGroupby.ColumnsIndices(key) != nil {
|
||||
coveredByUniqueKey = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if coveredByUniqueKey {
|
||||
// GroupByCols has unique key, so this aggregation can be removed.
|
||||
proj := convertAggToProj(agg, ap.ctx, ap.allocator)
|
||||
proj.SetParents(p.Parents()...)
|
||||
for _, child := range p.Children() {
|
||||
child.SetParents(proj)
|
||||
}
|
||||
retPlan = proj
|
||||
}
|
||||
}
|
||||
newChildren := make([]Plan, 0, len(p.Children()))
|
||||
for _, child := range p.Children() {
|
||||
newChild := ap.eliminateAggregation(child.(LogicalPlan))
|
||||
newChildren = append(newChildren, newChild)
|
||||
}
|
||||
retPlan.SetChildren(newChildren...)
|
||||
return retPlan
|
||||
}
|
||||
|
||||
func convertAggToProj(agg *Aggregation, ctx context.Context, allocator *idAllocator) *Projection {
|
||||
proj := &Projection{
|
||||
Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)),
|
||||
baseLogicalPlan: newBaseLogicalPlan(Proj, allocator),
|
||||
}
|
||||
proj.self = proj
|
||||
proj.initIDAndContext(ctx)
|
||||
for _, fun := range agg.AggFuncs {
|
||||
expr := rewriteExpr(fun.GetArgs(), fun.GetName(), ctx)
|
||||
proj.Exprs = append(proj.Exprs, expr)
|
||||
}
|
||||
proj.SetSchema(agg.schema.Clone())
|
||||
return proj
|
||||
}
|
||||
|
||||
// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function.
|
||||
func rewriteExpr(exprs []expression.Expression, funcName string, ctx context.Context) (newExpr expression.Expression) {
|
||||
switch funcName {
|
||||
case ast.AggFuncCount:
|
||||
// If is count(expr), we will change it to if(isnull(expr), 0, 1).
|
||||
// If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1).
|
||||
isNullExprs := make([]expression.Expression, 0, len(exprs))
|
||||
for _, expr := range exprs {
|
||||
isNullExpr, _ := expression.NewFunction(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr.Clone())
|
||||
isNullExprs = append(isNullExprs, isNullExpr)
|
||||
}
|
||||
innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...)
|
||||
newExpr, _ = expression.NewFunction(ctx, ast.If, types.NewFieldType(mysql.TypeLonglong), innerExpr, expression.Zero, expression.One)
|
||||
// See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
|
||||
// The SUM() and AVG() functions return a DECIMAL value for exact-value arguments (integer or DECIMAL),
|
||||
// and a DOUBLE value for approximate-value arguments (FLOAT or DOUBLE).
|
||||
case ast.AggFuncSum, ast.AggFuncAvg:
|
||||
expr := exprs[0].Clone()
|
||||
switch expr.GetType().Tp {
|
||||
// Integer type should be cast to decimal.
|
||||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
|
||||
newExpr = expression.NewCastFunc(types.NewFieldType(mysql.TypeNewDecimal), expr, ctx)
|
||||
// Double and Decimal doesn't need to be cast.
|
||||
case mysql.TypeDouble, mysql.TypeNewDecimal:
|
||||
newExpr = expr
|
||||
// Float should be cast to double. And other non-numeric type should be cast to double too.
|
||||
default:
|
||||
newExpr = expression.NewCastFunc(types.NewFieldType(mysql.TypeDouble), expr, ctx)
|
||||
}
|
||||
default:
|
||||
// Default we do nothing about expr.
|
||||
newExpr = exprs[0].Clone()
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -23,9 +23,9 @@ import (
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
type aggPushDownSolver struct {
|
||||
alloc *idAllocator
|
||||
ctx context.Context
|
||||
type aggregationOptimizer struct {
|
||||
allocator *idAllocator
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// isDecomposable checks if an aggregate function is decomposable. An aggregation function $F$ is decomposable
|
||||
@ -34,7 +34,7 @@ type aggPushDownSolver struct {
|
||||
// It's easy to see that max, min, first row is decomposable, no matter whether it's distinct, but sum(distinct) and
|
||||
// count(distinct) is not.
|
||||
// Currently we don't support avg and concat.
|
||||
func (a *aggPushDownSolver) isDecomposable(fun expression.AggregationFunction) bool {
|
||||
func (a *aggregationOptimizer) isDecomposable(fun expression.AggregationFunction) bool {
|
||||
switch fun.GetName() {
|
||||
case ast.AggFuncAvg, ast.AggFuncGroupConcat:
|
||||
// TODO: Support avg push down.
|
||||
@ -49,7 +49,7 @@ func (a *aggPushDownSolver) isDecomposable(fun expression.AggregationFunction) b
|
||||
}
|
||||
|
||||
// getAggFuncChildIdx gets which children it belongs to, 0 stands for left, 1 stands for right, -1 stands for both.
|
||||
func (a *aggPushDownSolver) getAggFuncChildIdx(aggFunc expression.AggregationFunction, schema *expression.Schema) int {
|
||||
func (a *aggregationOptimizer) getAggFuncChildIdx(aggFunc expression.AggregationFunction, schema *expression.Schema) int {
|
||||
fromLeft, fromRight := false, false
|
||||
var cols []*expression.Column
|
||||
for _, arg := range aggFunc.GetArgs() {
|
||||
@ -73,7 +73,7 @@ func (a *aggPushDownSolver) getAggFuncChildIdx(aggFunc expression.AggregationFun
|
||||
// collectAggFuncs collects all aggregate functions and splits them into two parts: "leftAggFuncs" and "rightAggFuncs" whose
|
||||
// arguments are all from left child or right child separately. If some aggregate functions have the arguments that have
|
||||
// columns both from left and right children, the whole aggregation is forbidden to push down.
|
||||
func (a *aggPushDownSolver) collectAggFuncs(agg *Aggregation, join *Join) (valid bool, leftAggFuncs, rightAggFuncs []expression.AggregationFunction) {
|
||||
func (a *aggregationOptimizer) collectAggFuncs(agg *Aggregation, join *Join) (valid bool, leftAggFuncs, rightAggFuncs []expression.AggregationFunction) {
|
||||
valid = true
|
||||
leftChild := join.children[0]
|
||||
for _, aggFunc := range agg.AggFuncs {
|
||||
@ -98,7 +98,7 @@ func (a *aggPushDownSolver) collectAggFuncs(agg *Aggregation, join *Join) (valid
|
||||
// query should be "SELECT SUM(B.agg) FROM A, (SELECT SUM(id) as agg, c1, c2, c3 FROM B GROUP BY id, c1, c2, c3) as B
|
||||
// WHERE A.c1 = B.c1 AND A.c2 != B.c2 GROUP BY B.c3". As you see, all the columns appearing in join-conditions should be
|
||||
// treated as group by columns in join subquery.
|
||||
func (a *aggPushDownSolver) collectGbyCols(agg *Aggregation, join *Join) (leftGbyCols, rightGbyCols []*expression.Column) {
|
||||
func (a *aggregationOptimizer) collectGbyCols(agg *Aggregation, join *Join) (leftGbyCols, rightGbyCols []*expression.Column) {
|
||||
leftChild := join.children[0]
|
||||
for _, gbyExpr := range agg.GroupByItems {
|
||||
cols := expression.ExtractColumns(gbyExpr)
|
||||
@ -136,7 +136,7 @@ func (a *aggPushDownSolver) collectGbyCols(agg *Aggregation, join *Join) (leftGb
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aggPushDownSolver) splitAggFuncsAndGbyCols(agg *Aggregation, join *Join) (valid bool,
|
||||
func (a *aggregationOptimizer) splitAggFuncsAndGbyCols(agg *Aggregation, join *Join) (valid bool,
|
||||
leftAggFuncs, rightAggFuncs []expression.AggregationFunction,
|
||||
leftGbyCols, rightGbyCols []*expression.Column) {
|
||||
valid, leftAggFuncs, rightAggFuncs = a.collectAggFuncs(agg, join)
|
||||
@ -148,7 +148,7 @@ func (a *aggPushDownSolver) splitAggFuncsAndGbyCols(agg *Aggregation, join *Join
|
||||
}
|
||||
|
||||
// addGbyCol adds a column to gbyCols. If a group by column has existed, it will not be added repeatedly.
|
||||
func (a *aggPushDownSolver) addGbyCol(gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column {
|
||||
func (a *aggregationOptimizer) addGbyCol(gbyCols []*expression.Column, cols ...*expression.Column) []*expression.Column {
|
||||
for _, c := range cols {
|
||||
duplicate := false
|
||||
for _, gbyCol := range gbyCols {
|
||||
@ -165,13 +165,13 @@ func (a *aggPushDownSolver) addGbyCol(gbyCols []*expression.Column, cols ...*exp
|
||||
}
|
||||
|
||||
// checkValidJoin checks if this join should be pushed across.
|
||||
func (a *aggPushDownSolver) checkValidJoin(join *Join) bool {
|
||||
func (a *aggregationOptimizer) checkValidJoin(join *Join) bool {
|
||||
return join.JoinType == InnerJoin || join.JoinType == LeftOuterJoin || join.JoinType == RightOuterJoin
|
||||
}
|
||||
|
||||
// decompose splits an aggregate function to two parts: a final mode function and a partial mode function. Currently
|
||||
// there are no differences between partial mode and complete mode, so we can confuse them.
|
||||
func (a *aggPushDownSolver) decompose(aggFunc expression.AggregationFunction, schema *expression.Schema, id string) ([]expression.AggregationFunction, *expression.Schema) {
|
||||
func (a *aggregationOptimizer) decompose(aggFunc expression.AggregationFunction, schema *expression.Schema, id string) ([]expression.AggregationFunction, *expression.Schema) {
|
||||
// Result is a slice because avg should be decomposed to sum and count. Currently we don't process this case.
|
||||
result := []expression.AggregationFunction{aggFunc.Clone()}
|
||||
for _, aggFunc := range result {
|
||||
@ -187,7 +187,7 @@ func (a *aggPushDownSolver) decompose(aggFunc expression.AggregationFunction, sc
|
||||
return result, schema
|
||||
}
|
||||
|
||||
func (a *aggPushDownSolver) allFirstRow(aggFuncs []expression.AggregationFunction) bool {
|
||||
func (a *aggregationOptimizer) allFirstRow(aggFuncs []expression.AggregationFunction) bool {
|
||||
for _, fun := range aggFuncs {
|
||||
if fun.GetName() != ast.AggFuncFirstRow {
|
||||
return false
|
||||
@ -199,7 +199,7 @@ func (a *aggPushDownSolver) allFirstRow(aggFuncs []expression.AggregationFunctio
|
||||
// tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't
|
||||
// process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator.
|
||||
// If the pushed aggregation is grouped by unique key, it's no need to push it down.
|
||||
func (a *aggPushDownSolver) tryToPushDownAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column, join *Join, childIdx int) LogicalPlan {
|
||||
func (a *aggregationOptimizer) tryToPushDownAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column, join *Join, childIdx int) LogicalPlan {
|
||||
child := join.children[childIdx].(LogicalPlan)
|
||||
if a.allFirstRow(aggFuncs) {
|
||||
return child
|
||||
@ -230,7 +230,7 @@ func (a *aggPushDownSolver) tryToPushDownAgg(aggFuncs []expression.AggregationFu
|
||||
return agg
|
||||
}
|
||||
|
||||
func (a *aggPushDownSolver) getDefaultValues(agg *Aggregation) ([]types.Datum, bool) {
|
||||
func (a *aggregationOptimizer) getDefaultValues(agg *Aggregation) ([]types.Datum, bool) {
|
||||
defaultValues := make([]types.Datum, 0, agg.Schema().Len())
|
||||
for _, aggFunc := range agg.AggFuncs {
|
||||
value, existsDefaultValue := aggFunc.CalculateDefaultValue(agg.children[0].Schema(), a.ctx)
|
||||
@ -242,7 +242,7 @@ func (a *aggPushDownSolver) getDefaultValues(agg *Aggregation) ([]types.Datum, b
|
||||
return defaultValues, true
|
||||
}
|
||||
|
||||
func (a *aggPushDownSolver) checkAnyCountAndSum(aggFuncs []expression.AggregationFunction) bool {
|
||||
func (a *aggregationOptimizer) checkAnyCountAndSum(aggFuncs []expression.AggregationFunction) bool {
|
||||
for _, fun := range aggFuncs {
|
||||
if fun.GetName() == ast.AggFuncSum || fun.GetName() == ast.AggFuncCount {
|
||||
return true
|
||||
@ -251,10 +251,10 @@ func (a *aggPushDownSolver) checkAnyCountAndSum(aggFuncs []expression.Aggregatio
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *aggPushDownSolver) makeNewAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column) *Aggregation {
|
||||
func (a *aggregationOptimizer) makeNewAgg(aggFuncs []expression.AggregationFunction, gbyCols []*expression.Column) *Aggregation {
|
||||
agg := &Aggregation{
|
||||
GroupByItems: expression.Column2Exprs(gbyCols),
|
||||
baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc),
|
||||
baseLogicalPlan: newBaseLogicalPlan(Agg, a.allocator),
|
||||
groupByCols: gbyCols,
|
||||
}
|
||||
agg.initIDAndContext(a.ctx)
|
||||
@ -277,11 +277,11 @@ func (a *aggPushDownSolver) makeNewAgg(aggFuncs []expression.AggregationFunction
|
||||
|
||||
// pushAggCrossUnion will try to push the agg down to the union. If the new aggregation's group-by columns doesn't contain unique key.
|
||||
// We will return the new aggregation. Otherwise we will transform the aggregation to projection.
|
||||
func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan {
|
||||
func (a *aggregationOptimizer) pushAggCrossUnion(agg *Aggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan {
|
||||
newAgg := &Aggregation{
|
||||
AggFuncs: make([]expression.AggregationFunction, 0, len(agg.AggFuncs)),
|
||||
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
|
||||
baseLogicalPlan: newBaseLogicalPlan(Agg, a.alloc),
|
||||
baseLogicalPlan: newBaseLogicalPlan(Agg, a.allocator),
|
||||
}
|
||||
newAgg.SetSchema(agg.schema.Clone())
|
||||
newAgg.initIDAndContext(a.ctx)
|
||||
@ -305,7 +305,7 @@ func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema *exp
|
||||
// this will cause error during executor phase.
|
||||
for _, key := range unionChild.Schema().Keys {
|
||||
if tmpSchema.ColumnsIndices(key) != nil {
|
||||
proj := convertAggToProj(newAgg, a.ctx, a.alloc)
|
||||
proj := a.convertAggToProj(newAgg, a.ctx, a.allocator)
|
||||
proj.SetChildren(unionChild)
|
||||
return proj
|
||||
}
|
||||
@ -315,72 +315,173 @@ func (a *aggPushDownSolver) pushAggCrossUnion(agg *Aggregation, unionSchema *exp
|
||||
return newAgg
|
||||
}
|
||||
|
||||
func (a *aggPushDownSolver) optimize(p LogicalPlan, ctx context.Context, alloc *idAllocator) (LogicalPlan, error) {
|
||||
func (a *aggregationOptimizer) optimize(p LogicalPlan, ctx context.Context, alloc *idAllocator) (LogicalPlan, error) {
|
||||
if !ctx.GetSessionVars().AllowAggPushDown {
|
||||
return p, nil
|
||||
}
|
||||
a.ctx = ctx
|
||||
a.alloc = alloc
|
||||
a.allocator = alloc
|
||||
a.aggPushDown(p)
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// aggPushDown tries to push down aggregate functions to join paths.
|
||||
func (a *aggPushDownSolver) aggPushDown(p LogicalPlan) {
|
||||
func (a *aggregationOptimizer) aggPushDown(p LogicalPlan) LogicalPlan {
|
||||
if agg, ok := p.(*Aggregation); ok {
|
||||
child := agg.children[0]
|
||||
if join, ok1 := child.(*Join); ok1 && a.checkValidJoin(join) {
|
||||
if valid, leftAggFuncs, rightAggFuncs, leftGbyCols, rightGbyCols := a.splitAggFuncsAndGbyCols(agg, join); valid {
|
||||
var lChild, rChild LogicalPlan
|
||||
// If there exist count or sum functions in left join path, we can't push any
|
||||
// aggregate function into right join path.
|
||||
rightInvalid := a.checkAnyCountAndSum(leftAggFuncs)
|
||||
leftInvalid := a.checkAnyCountAndSum(rightAggFuncs)
|
||||
if rightInvalid {
|
||||
rChild = join.children[1].(LogicalPlan)
|
||||
} else {
|
||||
rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1)
|
||||
proj := a.tryToEliminateAggregation(agg)
|
||||
if proj != nil {
|
||||
p = proj
|
||||
} else {
|
||||
child := agg.children[0]
|
||||
if join, ok1 := child.(*Join); ok1 && a.checkValidJoin(join) {
|
||||
if valid, leftAggFuncs, rightAggFuncs, leftGbyCols, rightGbyCols := a.splitAggFuncsAndGbyCols(agg, join); valid {
|
||||
var lChild, rChild LogicalPlan
|
||||
// If there exist count or sum functions in left join path, we can't push any
|
||||
// aggregate function into right join path.
|
||||
rightInvalid := a.checkAnyCountAndSum(leftAggFuncs)
|
||||
leftInvalid := a.checkAnyCountAndSum(rightAggFuncs)
|
||||
if rightInvalid {
|
||||
rChild = join.children[1].(LogicalPlan)
|
||||
} else {
|
||||
rChild = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1)
|
||||
}
|
||||
if leftInvalid {
|
||||
lChild = join.children[0].(LogicalPlan)
|
||||
} else {
|
||||
lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0)
|
||||
}
|
||||
join.SetChildren(lChild, rChild)
|
||||
lChild.SetParents(join)
|
||||
rChild.SetParents(join)
|
||||
join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema()))
|
||||
join.buildKeyInfo()
|
||||
proj := a.tryToEliminateAggregation(agg)
|
||||
if proj != nil {
|
||||
p = proj
|
||||
}
|
||||
}
|
||||
if leftInvalid {
|
||||
lChild = join.children[0].(LogicalPlan)
|
||||
} else {
|
||||
lChild = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0)
|
||||
} else if proj, ok1 := child.(*Projection); ok1 {
|
||||
// TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet,
|
||||
// so we must do this optimization.
|
||||
for i, gbyItem := range agg.GroupByItems {
|
||||
agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)
|
||||
}
|
||||
join.SetChildren(lChild, rChild)
|
||||
lChild.SetParents(join)
|
||||
rChild.SetParents(join)
|
||||
join.SetSchema(expression.MergeSchema(lChild.Schema(), rChild.Schema()))
|
||||
}
|
||||
} else if proj, ok1 := child.(*Projection); ok1 {
|
||||
// TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet,
|
||||
// so we must do this optimization.
|
||||
for i, gbyItem := range agg.GroupByItems {
|
||||
agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)
|
||||
}
|
||||
agg.collectGroupByColumns()
|
||||
for _, aggFunc := range agg.AggFuncs {
|
||||
newArgs := make([]expression.Expression, 0, len(aggFunc.GetArgs()))
|
||||
for _, arg := range aggFunc.GetArgs() {
|
||||
newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
|
||||
agg.collectGroupByColumns()
|
||||
for _, aggFunc := range agg.AggFuncs {
|
||||
newArgs := make([]expression.Expression, 0, len(aggFunc.GetArgs()))
|
||||
for _, arg := range aggFunc.GetArgs() {
|
||||
newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs))
|
||||
}
|
||||
aggFunc.SetArgs(newArgs)
|
||||
}
|
||||
aggFunc.SetArgs(newArgs)
|
||||
projChild := proj.children[0]
|
||||
agg.SetChildren(projChild)
|
||||
projChild.SetParents(agg)
|
||||
} else if union, ok1 := child.(*Union); ok1 {
|
||||
pushedAgg := a.makeNewAgg(agg.AggFuncs, agg.groupByCols)
|
||||
newChildren := make([]Plan, 0, len(union.children))
|
||||
for _, child := range union.children {
|
||||
newChild := a.pushAggCrossUnion(pushedAgg, union.schema, child.(LogicalPlan))
|
||||
newChildren = append(newChildren, newChild)
|
||||
newChild.SetParents(union)
|
||||
}
|
||||
union.SetChildren(newChildren...)
|
||||
union.SetSchema(pushedAgg.schema)
|
||||
}
|
||||
projChild := proj.children[0]
|
||||
agg.SetChildren(projChild)
|
||||
projChild.SetParents(agg)
|
||||
} else if union, ok1 := child.(*Union); ok1 {
|
||||
pushedAgg := a.makeNewAgg(agg.AggFuncs, agg.groupByCols)
|
||||
newChildren := make([]Plan, 0, len(union.children))
|
||||
for _, child := range union.children {
|
||||
newChild := a.pushAggCrossUnion(pushedAgg, union.schema, child.(LogicalPlan))
|
||||
newChildren = append(newChildren, newChild)
|
||||
newChild.SetParents(union)
|
||||
}
|
||||
union.SetChildren(newChildren...)
|
||||
union.SetSchema(pushedAgg.schema)
|
||||
}
|
||||
}
|
||||
newChildren := make([]Plan, 0, len(p.Children()))
|
||||
for _, child := range p.Children() {
|
||||
a.aggPushDown(child.(LogicalPlan))
|
||||
newChild := a.aggPushDown(child.(LogicalPlan))
|
||||
newChild.SetParents(p)
|
||||
newChildren = append(newChildren, newChild)
|
||||
}
|
||||
p.SetChildren(newChildren...)
|
||||
return p
|
||||
}
|
||||
|
||||
// tryToEliminateAggregation will eliminate aggregation grouped by unique key.
|
||||
// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`.
|
||||
// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
|
||||
// If we can eliminate agg successful, we return a projection. Else we return a nil pointer.
|
||||
func (a *aggregationOptimizer) tryToEliminateAggregation(agg *Aggregation) *Projection {
|
||||
schemaByGroupby := expression.NewSchema(agg.groupByCols...)
|
||||
coveredByUniqueKey := false
|
||||
for _, key := range agg.children[0].Schema().Keys {
|
||||
if schemaByGroupby.ColumnsIndices(key) != nil {
|
||||
coveredByUniqueKey = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if coveredByUniqueKey {
|
||||
// GroupByCols has unique key, so this aggregation can be removed.
|
||||
proj := a.convertAggToProj(agg, a.ctx, a.allocator)
|
||||
proj.SetChildren(agg.children[0])
|
||||
agg.children[0].SetParents(proj)
|
||||
return proj
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *aggregationOptimizer) convertAggToProj(agg *Aggregation, ctx context.Context, allocator *idAllocator) *Projection {
|
||||
proj := &Projection{
|
||||
Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)),
|
||||
baseLogicalPlan: newBaseLogicalPlan(Proj, allocator),
|
||||
}
|
||||
proj.self = proj
|
||||
proj.initIDAndContext(ctx)
|
||||
for _, fun := range agg.AggFuncs {
|
||||
expr := a.rewriteExpr(fun)
|
||||
proj.Exprs = append(proj.Exprs, expr)
|
||||
}
|
||||
proj.SetSchema(agg.schema.Clone())
|
||||
return proj
|
||||
}
|
||||
|
||||
func (a *aggregationOptimizer) rewriteCount(exprs []expression.Expression) expression.Expression {
|
||||
// If is count(expr), we will change it to if(isnull(expr), 0, 1).
|
||||
// If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1).
|
||||
isNullExprs := make([]expression.Expression, 0, len(exprs))
|
||||
for _, expr := range exprs {
|
||||
isNullExpr, _ := expression.NewFunction(a.ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr.Clone())
|
||||
isNullExprs = append(isNullExprs, isNullExpr)
|
||||
}
|
||||
innerExpr := expression.ComposeDNFCondition(a.ctx, isNullExprs...)
|
||||
newExpr, _ := expression.NewFunction(a.ctx, ast.If, types.NewFieldType(mysql.TypeLonglong), innerExpr, expression.Zero, expression.One)
|
||||
return newExpr
|
||||
}
|
||||
|
||||
// See https://dev.mysql.com/doc/refman/5.7/en/group-by-functions.html
|
||||
// The SUM() and AVG() functions return a DECIMAL value for exact-value arguments (integer or DECIMAL),
|
||||
// and a DOUBLE value for approximate-value arguments (FLOAT or DOUBLE).
|
||||
func (a *aggregationOptimizer) rewriteSumOrAvg(exprs []expression.Expression) expression.Expression {
|
||||
// FIXME: Consider the case that avg is final mode.
|
||||
expr := exprs[0].Clone()
|
||||
switch expr.GetType().Tp {
|
||||
// Integer type should be cast to decimal.
|
||||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
|
||||
return expression.NewCastFunc(types.NewFieldType(mysql.TypeNewDecimal), expr, a.ctx)
|
||||
// Double and Decimal doesn't need to be cast.
|
||||
case mysql.TypeDouble, mysql.TypeNewDecimal:
|
||||
return expr
|
||||
// Float should be cast to double. And other non-numeric type should be cast to double too.
|
||||
default:
|
||||
return expression.NewCastFunc(types.NewFieldType(mysql.TypeDouble), expr, a.ctx)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function.
|
||||
func (a *aggregationOptimizer) rewriteExpr(aggFunc expression.AggregationFunction) expression.Expression {
|
||||
switch aggFunc.GetName() {
|
||||
case ast.AggFuncCount:
|
||||
if aggFunc.GetMode() == expression.FinalMode {
|
||||
return a.rewriteSumOrAvg(aggFunc.GetArgs())
|
||||
}
|
||||
return a.rewriteCount(aggFunc.GetArgs())
|
||||
case ast.AggFuncSum, ast.AggFuncAvg:
|
||||
return a.rewriteSumOrAvg(aggFunc.GetArgs())
|
||||
default:
|
||||
// Default we do nothing about expr.
|
||||
return aggFunc.GetArgs()[0].Clone()
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,6 +40,16 @@ func (p *Aggregation) buildKeyInfo() {
|
||||
}
|
||||
p.schema.Keys = append(p.schema.Keys, newKey)
|
||||
}
|
||||
if len(p.groupByCols) == len(p.GroupByItems) && len(p.GroupByItems) > 0 {
|
||||
indices := p.schema.ColumnsIndices(p.groupByCols)
|
||||
if indices != nil {
|
||||
newKey := make([]*expression.Column, 0, len(indices))
|
||||
for _, i := range indices {
|
||||
newKey = append(newKey, p.schema.Columns[i])
|
||||
}
|
||||
p.schema.Keys = append(p.schema.Keys, newKey)
|
||||
}
|
||||
}
|
||||
if len(p.GroupByItems) == 0 {
|
||||
p.schema.MaxOneRow = true
|
||||
}
|
||||
|
||||
@ -153,6 +153,7 @@ func (s *decorrelateSolver) optimize(p LogicalPlan, _ context.Context, _ *idAllo
|
||||
np, _ := s.optimize(p, nil, nil)
|
||||
agg.SetChildren(np)
|
||||
np.SetParents(agg)
|
||||
agg.collectGroupByColumns()
|
||||
return agg, nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -47,8 +47,7 @@ func (p *Aggregation) collectGroupByColumns() {
|
||||
|
||||
func (b *planBuilder) buildAggregation(p LogicalPlan, aggFuncList []*ast.AggregateFuncExpr, gbyItems []expression.Expression) (LogicalPlan, map[int]int) {
|
||||
b.optFlag = b.optFlag | flagBuildKeyInfo
|
||||
b.optFlag = b.optFlag | flagEliminateAgg
|
||||
b.optFlag = b.optFlag | flagAggPushDown
|
||||
b.optFlag = b.optFlag | flagAggregationOptimize
|
||||
agg := &Aggregation{
|
||||
AggFuncs: make([]expression.AggregationFunction, 0, len(aggFuncList)),
|
||||
baseLogicalPlan: newBaseLogicalPlan(Agg, b.allocator)}
|
||||
@ -325,8 +324,7 @@ func (b *planBuilder) buildProjection(p LogicalPlan, fields []*ast.SelectField,
|
||||
|
||||
func (b *planBuilder) buildDistinct(child LogicalPlan, length int) LogicalPlan {
|
||||
b.optFlag = b.optFlag | flagBuildKeyInfo
|
||||
b.optFlag = b.optFlag | flagEliminateAgg
|
||||
b.optFlag = b.optFlag | flagAggPushDown
|
||||
b.optFlag = b.optFlag | flagAggregationOptimize
|
||||
agg := &Aggregation{
|
||||
baseLogicalPlan: newBaseLogicalPlan(Agg, b.allocator),
|
||||
AggFuncs: make([]expression.AggregationFunction, 0, child.Schema().Len()),
|
||||
|
||||
@ -743,7 +743,7 @@ func (s *testPlanSuite) TestAggPushDown(c *C) {
|
||||
},
|
||||
{
|
||||
sql: "select max(a.b), max(b.b) from t a join t b on a.c = b.c group by a.a",
|
||||
best: "Join{DataScan(a)->DataScan(b)->Aggr(max(b.b),firstrow(b.c))}(a.c,b.c)->Aggr(max(a.b),max(join_agg_0))->Projection",
|
||||
best: "Join{DataScan(a)->DataScan(b)->Aggr(max(b.b),firstrow(b.c))}(a.c,b.c)->Projection->Projection",
|
||||
},
|
||||
{
|
||||
sql: "select max(a.b), max(b.b) from t a join t b on a.a = b.a group by a.c",
|
||||
@ -775,7 +775,7 @@ func (s *testPlanSuite) TestAggPushDown(c *C) {
|
||||
p := builder.build(stmt)
|
||||
c.Assert(builder.err, IsNil)
|
||||
lp := p.(LogicalPlan)
|
||||
p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagAggPushDown, lp.(LogicalPlan), builder.ctx, builder.allocator)
|
||||
p, err = logicalOptimize(flagBuildKeyInfo|flagPredicatePushDown|flagPrunColumns|flagAggregationOptimize, lp.(LogicalPlan), builder.ctx, builder.allocator)
|
||||
lp.ResolveIndicesAndCorCols()
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(ToString(lp), Equals, ca.best, Commentf("for %s", ca.sql))
|
||||
@ -1290,7 +1290,7 @@ func (s *testPlanSuite) TestUniqueKeyInfo(c *C) {
|
||||
ans map[string][][]string
|
||||
}{
|
||||
{
|
||||
sql: "select a, sum(e) from t group by a",
|
||||
sql: "select a, sum(e) from t group by b",
|
||||
ans: map[string][][]string{
|
||||
"TableScan_1": {{"test.t.a"}},
|
||||
"Aggregation_2": {{"test.t.a"}},
|
||||
@ -1298,23 +1298,23 @@ func (s *testPlanSuite) TestUniqueKeyInfo(c *C) {
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: "select a, sum(f) from t group by a",
|
||||
sql: "select a, b, sum(f) from t group by b",
|
||||
ans: map[string][][]string{
|
||||
"TableScan_1": {{"test.t.f"}, {"test.t.a"}},
|
||||
"Aggregation_2": {{"test.t.a"}},
|
||||
"Projection_3": {{"a"}},
|
||||
"Aggregation_2": {{"test.t.a"}, {"test.t.b"}},
|
||||
"Projection_3": {{"a"}, {"b"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: "select c, d, e, sum(a) from t group by c, d, e",
|
||||
ans: map[string][][]string{
|
||||
"TableScan_1": {{"test.t.a"}},
|
||||
"Aggregation_2": nil,
|
||||
"Projection_3": nil,
|
||||
"Aggregation_2": {{"test.t.c", "test.t.d", "test.t.e"}},
|
||||
"Projection_3": {{"c", "d", "e"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
sql: "select f, g, sum(a) from t group by f, g",
|
||||
sql: "select f, g, sum(a) from t",
|
||||
ans: map[string][][]string{
|
||||
"TableScan_1": {{"test.t.f"}, {"test.t.g"}, {"test.t.f", "test.t.g"}, {"test.t.a"}},
|
||||
"Aggregation_2": {{"test.t.f"}, {"test.t.g"}, {"test.t.f", "test.t.g"}},
|
||||
@ -1394,6 +1394,10 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
|
||||
sql: "select tt.a, sum(tt.b) from (select a, b from t) tt group by tt.a",
|
||||
best: "DataScan(t)->Projection->Projection->Projection",
|
||||
},
|
||||
{
|
||||
sql: "select count(1) from (select count(1), a as b from t group by a) tt group by b",
|
||||
best: "DataScan(t)->Projection->Projection->Projection->Projection",
|
||||
},
|
||||
}
|
||||
for _, ca := range cases {
|
||||
comment := Commentf("for %s", ca.sql)
|
||||
@ -1410,7 +1414,7 @@ func (s *testPlanSuite) TestAggPrune(c *C) {
|
||||
}
|
||||
p := builder.build(stmt).(LogicalPlan)
|
||||
c.Assert(builder.err, IsNil)
|
||||
p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagEliminateAgg, p.(LogicalPlan), builder.ctx, builder.allocator)
|
||||
p, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagBuildKeyInfo|flagAggregationOptimize, p.(LogicalPlan), builder.ctx, builder.allocator)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(ToString(p), Equals, ca.best, comment)
|
||||
}
|
||||
|
||||
@ -15,6 +15,7 @@ package plan
|
||||
|
||||
import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/ngaut/log"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
@ -35,8 +36,7 @@ const (
|
||||
flagBuildKeyInfo
|
||||
flagDecorrelate
|
||||
flagPredicatePushDown
|
||||
flagEliminateAgg
|
||||
flagAggPushDown
|
||||
flagAggregationOptimize
|
||||
)
|
||||
|
||||
var optRuleList = []logicalOptRule{
|
||||
@ -44,8 +44,7 @@ var optRuleList = []logicalOptRule{
|
||||
&buildKeySolver{},
|
||||
&decorrelateSolver{},
|
||||
&ppdSolver{},
|
||||
&aggPruner{},
|
||||
&aggPushDownSolver{},
|
||||
&aggregationOptimizer{},
|
||||
}
|
||||
|
||||
// logicalOptRule means a logical optimizing rule, which contains decorrelate, ppd, column pruning, etc.
|
||||
@ -104,6 +103,7 @@ func doOptimize(flag uint64, logic LogicalPlan, ctx context.Context, allocator *
|
||||
return nil, errors.Trace(ErrCartesianProductUnsupported)
|
||||
}
|
||||
logic.ResolveIndicesAndCorCols()
|
||||
log.Warnf("PLAN %s", ToString(logic))
|
||||
return physicalOptimize(flag, logic, allocator)
|
||||
}
|
||||
|
||||
|
||||
@ -85,7 +85,7 @@ func (s *testPlanSuite) TestPushDownAggregation(c *C) {
|
||||
_, lp, err = lp.PredicatePushDown(nil)
|
||||
c.Assert(err, IsNil)
|
||||
lp.PruneColumns(lp.Schema().Columns)
|
||||
solver := &aggPushDownSolver{builder.allocator, builder.ctx}
|
||||
solver := &aggregationOptimizer{builder.allocator, builder.ctx}
|
||||
solver.aggPushDown(lp)
|
||||
lp.ResolveIndicesAndCorCols()
|
||||
info, err := lp.convert2PhysicalPlan(&requiredProperty{})
|
||||
@ -438,12 +438,12 @@ func (s *testPlanSuite) TestCBO(c *C) {
|
||||
best: "Table(t)->HashAgg->Sort + Limit(1) + Offset(0)->Projection",
|
||||
},
|
||||
{
|
||||
sql: "select count(*) from t where concat(a,b) = 'abc' group by a",
|
||||
best: "Table(t)->Selection->StreamAgg",
|
||||
sql: "select count(*) from t where concat(a,b) = 'abc' group by c",
|
||||
best: "Index(t.c_d_e)[[<nil>,+inf]]->Selection->StreamAgg",
|
||||
},
|
||||
{
|
||||
sql: "select count(*) from t where concat(a,b) = 'abc' group by a order by a",
|
||||
best: "Table(t)->Selection->StreamAgg->Projection",
|
||||
best: "Table(t)->Selection->Projection->Projection",
|
||||
},
|
||||
{
|
||||
sql: "select count(distinct e) from t where c = 1 and concat(c,d) = 'abc' group by d",
|
||||
@ -508,7 +508,7 @@ func (s *testPlanSuite) TestCBO(c *C) {
|
||||
},
|
||||
{
|
||||
sql: "select * from (select t.a from t union select t.d from t where t.c = 1 union select t.c from t) k order by a limit 1",
|
||||
best: "UnionAll{Table(t)->HashAgg->Index(t.c_d_e)[[1,1]]->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
|
||||
best: "UnionAll{Table(t)->Projection->Index(t.c_d_e)[[1,1]]->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
|
||||
},
|
||||
{
|
||||
sql: "select * from (select t.a from t union all select t.d from t where t.c = 1 union all select t.c from t) k order by a limit 1",
|
||||
@ -516,7 +516,11 @@ func (s *testPlanSuite) TestCBO(c *C) {
|
||||
},
|
||||
{
|
||||
sql: "select * from (select t.a from t union select t.d from t union select t.c from t) k order by a limit 1",
|
||||
best: "UnionAll{Table(t)->HashAgg->Table(t)->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
|
||||
best: "UnionAll{Table(t)->Projection->Table(t)->HashAgg->Table(t)->HashAgg}->HashAgg->Sort + Limit(1) + Offset(0)",
|
||||
},
|
||||
{
|
||||
sql: "select t.c from t where 0 = (select count(b) from t t1 where t.a = t1.b)",
|
||||
best: "LeftHashJoin{Table(t)->Table(t)->HashAgg}(test.t.a,t1.b)->Projection->Selection->Projection",
|
||||
},
|
||||
}
|
||||
for _, ca := range cases {
|
||||
@ -536,7 +540,7 @@ func (s *testPlanSuite) TestCBO(c *C) {
|
||||
p := builder.build(stmt)
|
||||
c.Assert(builder.err, IsNil)
|
||||
lp := p.(LogicalPlan)
|
||||
lp, err = logicalOptimize(flagPredicatePushDown|flagPrunColumns|flagAggPushDown|flagDecorrelate, lp, builder.ctx, builder.allocator)
|
||||
lp, err = logicalOptimize(flagPredicatePushDown|flagBuildKeyInfo|flagPrunColumns|flagAggregationOptimize|flagDecorrelate, lp, builder.ctx, builder.allocator)
|
||||
lp.ResolveIndicesAndCorCols()
|
||||
info, err := lp.convert2PhysicalPlan(&requiredProperty{})
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
Reference in New Issue
Block a user