Files
tidb/planner/core/rule_aggregation_elimination.go

144 lines
5.5 KiB
Go

// Copyright 2018 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package core
import (
"math"
"github.com/pingcap/parser/ast"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
)
// aggregationEliminator walks a logical plan (see its optimize method) and
// replaces every aggregation that can be eliminated with a projection.
// It embeds aggregationEliminateChecker to reuse the check and the rewrite
// helpers defined on that type.
type aggregationEliminator struct {
aggregationEliminateChecker
}
// aggregationEliminateChecker is a stateless helper that decides whether a
// LogicalAggregation can be replaced by a LogicalProjection and performs the
// rewrite of each aggregate function into a plain expression.
type aggregationEliminateChecker struct {
}
// tryToEliminateAggregation tries to replace an aggregation that is grouped by
// a unique key with a plain projection.
// e.g. select min(b) from t group by a. If a is a unique key, every group holds
// exactly one row, so this SQL is equal to `select b from t group by a`.
// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) the
// expression may need to be rewritten; see rewriteExpr for the details.
// It returns the projection, already attached to agg's child, when the
// aggregation can be eliminated, and nil otherwise.
func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation) *LogicalProjection {
	for _, af := range agg.AggFuncs {
		// GROUP_CONCAT cannot be turned into a projection safely: its result
		// depends on every argument, on the separator and on the truncation
		// imposed by `group_concat_max_len`, while rewriteExpr would keep only
		// a cast of the first argument. Keep the aggregation in that case.
		if af.Name == ast.AggFuncGroupConcat {
			return nil
		}
	}

	schemaByGroupby := expression.NewSchema(agg.groupByCols...)
	for _, key := range agg.children[0].Schema().Keys {
		// A nil result means this key is not covered by the group-by columns.
		if schemaByGroupby.ColumnsIndices(key) == nil {
			continue
		}
		// GroupByCols has a unique key, so this aggregation can be removed.
		proj := a.convertAggToProj(agg)
		proj.SetChildren(agg.children[0])
		return proj
	}
	return nil
}
// convertAggToProj builds a projection whose expressions are the rewritten,
// aggregate-free forms of agg's aggregate functions, in the same order.
// The projection gets a clone of agg's schema; its children are NOT set here
// and must be attached by the caller.
func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) *LogicalProjection {
	sctx := agg.ctx
	projection := LogicalProjection{
		Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)),
	}.Init(sctx)
	for _, desc := range agg.AggFuncs {
		projection.Exprs = append(projection.Exprs, a.rewriteExpr(sctx, desc))
	}
	projection.SetSchema(agg.schema.Clone())
	return projection
}
// rewriteExpr converts an aggregate function descriptor into an equivalent
// expression that contains no aggregate function. It is only meaningful when
// each group consists of a single row.
//
//   - bit_and / bit_or / bit_xor are delegated to rewriteBitFunc.
//   - sum / avg / first_row / max / min / group_concat become the first
//     argument, cast to the aggregate's return type when needed.
//   - count in FinalMode becomes its first argument cast to the return type
//     (presumably the argument already carries a partial count — confirm
//     against the producer of FinalMode descriptors); in any other mode it is
//     delegated to rewriteCount.
//
// It panics with "Unsupported function" for any other aggregate name.
func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression {
	retTp := aggFunc.RetTp
	switch aggFunc.Name {
	case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
		return a.rewriteBitFunc(ctx, aggFunc.Name, aggFunc.Args[0], retTp)
	case ast.AggFuncSum, ast.AggFuncAvg, ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncGroupConcat:
		return a.wrapCastFunction(ctx, aggFunc.Args[0], retTp)
	case ast.AggFuncCount:
		if aggFunc.Mode != aggregation.FinalMode {
			return a.rewriteCount(ctx, aggFunc.Args, retTp)
		}
		return a.wrapCastFunction(ctx, aggFunc.Args[0], retTp)
	}
	panic("Unsupported function")
}
// rewriteCount rewrites count over a single-row group into an IF expression
// typed as targetTp:
//
//   - count(expr)              => if(isnull(expr), 0, 1)
//   - count(distinct x, y, z)  => if(isnull(x) or isnull(y) or isnull(z), 0, 1)
//
// An argument whose type carries the NOT NULL flag can never be null, so the
// constant 0 is used in place of its isnull(...) check.
func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression {
	nullChecks := make([]expression.Expression, 0, len(exprs))
	for _, arg := range exprs {
		if mysql.HasNotNullFlag(arg.GetType().Flag) {
			nullChecks = append(nullChecks, expression.Zero)
			continue
		}
		nullChecks = append(nullChecks,
			expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), arg))
	}

	// OR all the checks together: one null argument is enough to make the count 0.
	anyNull := expression.ComposeDNFCondition(ctx, nullChecks...)
	return expression.NewFunctionInternal(ctx, ast.If, targetTp, anyNull, expression.Zero, expression.One)
}
// rewriteBitFunc rewrites bit_and / bit_or / bit_xor over a single-row group
// into ifnull(cast(arg), <default>), where funcType is the aggregate's name.
//
// The argument is first wrapped with a cast to int and then cast to targetTp,
// so that non-integer arguments work with the bit functions
// (cast(cast(arg as signed) as unsigned)). When the result is null the
// default is the unsigned constant math.MaxUint64 for bit_and and 0 for the
// other two functions.
func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression {
	intArg := expression.WrapWithCastAsInt(ctx, arg)
	castArg := a.wrapCastFunction(ctx, intArg, targetTp)

	if funcType == ast.AggFuncBitAnd {
		allOnes := &expression.Constant{Value: types.NewUintDatum(math.MaxUint64), RetType: targetTp}
		// This branch types the ifnull with the cast's own type rather than targetTp.
		return expression.NewFunctionInternal(ctx, ast.Ifnull, castArg.GetType(), castArg, allOnes)
	}
	return expression.NewFunctionInternal(ctx, ast.Ifnull, targetTp, castArg, expression.Zero.Clone())
}
// wrapCastFunction returns arg unchanged when its type is targetTp, and
// otherwise wraps it in a cast to targetTp.
// Note that the two types are compared with ==, i.e. by pointer identity and
// not field by field, so a cast is added even when the type descriptors are
// structurally equal but distinct objects.
func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression {
	if arg.GetType() != targetTp {
		return expression.BuildCastFunction(ctx, arg, targetTp)
	}
	return arg
}
// optimize applies aggregation elimination to the plan rooted at p.
// Children are optimized first and re-attached to p; then, if p itself is a
// LogicalAggregation that can be eliminated, the replacing projection is
// returned instead of p. Any error from optimizing a child is returned as is,
// with a nil plan.
func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) {
	children := p.Children()
	optimized := make([]LogicalPlan, 0, len(children))
	for _, c := range children {
		sub, err := a.optimize(c)
		if err != nil {
			return nil, err
		}
		optimized = append(optimized, sub)
	}
	p.SetChildren(optimized...)

	if agg, isAgg := p.(*LogicalAggregation); isAgg {
		if proj := a.tryToEliminateAggregation(agg); proj != nil {
			return proj, nil
		}
	}
	return p, nil
}