144 lines
5.5 KiB
Go
144 lines
5.5 KiB
Go
// Copyright 2018 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package core
|
|
|
|
import (
|
|
"math"
|
|
|
|
"github.com/pingcap/parser/ast"
|
|
"github.com/pingcap/parser/mysql"
|
|
"github.com/pingcap/tidb/expression"
|
|
"github.com/pingcap/tidb/expression/aggregation"
|
|
"github.com/pingcap/tidb/sessionctx"
|
|
"github.com/pingcap/tidb/types"
|
|
)
|
|
|
|
type aggregationEliminator struct {
|
|
aggregationEliminateChecker
|
|
}
|
|
|
|
// aggregationEliminateChecker is a stateless holder for the methods that
// decide whether an aggregation can be removed and that rewrite its aggregate
// functions into ordinary expressions.
type aggregationEliminateChecker struct{}
|
|
|
|
// tryToEliminateAggregation will eliminate aggregation grouped by unique key.
|
|
// e.g. select min(b) from t group by a. If a is a unique key, then this sql is equal to `select b from t group by a`.
|
|
// For count(expr), sum(expr), avg(expr), count(distinct expr, [expr...]) we may need to rewrite the expr. Details are shown below.
|
|
// If we can eliminate agg successful, we return a projection. Else we return a nil pointer.
|
|
func (a *aggregationEliminateChecker) tryToEliminateAggregation(agg *LogicalAggregation) *LogicalProjection {
|
|
schemaByGroupby := expression.NewSchema(agg.groupByCols...)
|
|
coveredByUniqueKey := false
|
|
for _, key := range agg.children[0].Schema().Keys {
|
|
if schemaByGroupby.ColumnsIndices(key) != nil {
|
|
coveredByUniqueKey = true
|
|
break
|
|
}
|
|
}
|
|
if coveredByUniqueKey {
|
|
// GroupByCols has unique key, so this aggregation can be removed.
|
|
proj := a.convertAggToProj(agg)
|
|
proj.SetChildren(agg.children[0])
|
|
return proj
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (a *aggregationEliminateChecker) convertAggToProj(agg *LogicalAggregation) *LogicalProjection {
|
|
proj := LogicalProjection{
|
|
Exprs: make([]expression.Expression, 0, len(agg.AggFuncs)),
|
|
}.Init(agg.ctx)
|
|
for _, fun := range agg.AggFuncs {
|
|
expr := a.rewriteExpr(agg.ctx, fun)
|
|
proj.Exprs = append(proj.Exprs, expr)
|
|
}
|
|
proj.SetSchema(agg.schema.Clone())
|
|
return proj
|
|
}
|
|
|
|
// rewriteExpr will rewrite the aggregate function to expression doesn't contain aggregate function.
|
|
func (a *aggregationEliminateChecker) rewriteExpr(ctx sessionctx.Context, aggFunc *aggregation.AggFuncDesc) expression.Expression {
|
|
switch aggFunc.Name {
|
|
case ast.AggFuncCount:
|
|
if aggFunc.Mode == aggregation.FinalMode {
|
|
return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
|
|
}
|
|
return a.rewriteCount(ctx, aggFunc.Args, aggFunc.RetTp)
|
|
case ast.AggFuncSum, ast.AggFuncAvg, ast.AggFuncFirstRow, ast.AggFuncMax, ast.AggFuncMin, ast.AggFuncGroupConcat:
|
|
return a.wrapCastFunction(ctx, aggFunc.Args[0], aggFunc.RetTp)
|
|
case ast.AggFuncBitAnd, ast.AggFuncBitOr, ast.AggFuncBitXor:
|
|
return a.rewriteBitFunc(ctx, aggFunc.Name, aggFunc.Args[0], aggFunc.RetTp)
|
|
default:
|
|
panic("Unsupported function")
|
|
}
|
|
}
|
|
|
|
func (a *aggregationEliminateChecker) rewriteCount(ctx sessionctx.Context, exprs []expression.Expression, targetTp *types.FieldType) expression.Expression {
|
|
// If is count(expr), we will change it to if(isnull(expr), 0, 1).
|
|
// If is count(distinct x, y, z) we will change it to if(isnull(x) or isnull(y) or isnull(z), 0, 1).
|
|
// If is count(expr not null), we will change it to constant 1.
|
|
isNullExprs := make([]expression.Expression, 0, len(exprs))
|
|
for _, expr := range exprs {
|
|
if mysql.HasNotNullFlag(expr.GetType().Flag) {
|
|
isNullExprs = append(isNullExprs, expression.Zero)
|
|
} else {
|
|
isNullExpr := expression.NewFunctionInternal(ctx, ast.IsNull, types.NewFieldType(mysql.TypeTiny), expr)
|
|
isNullExprs = append(isNullExprs, isNullExpr)
|
|
}
|
|
}
|
|
|
|
innerExpr := expression.ComposeDNFCondition(ctx, isNullExprs...)
|
|
newExpr := expression.NewFunctionInternal(ctx, ast.If, targetTp, innerExpr, expression.Zero, expression.One)
|
|
return newExpr
|
|
}
|
|
|
|
func (a *aggregationEliminateChecker) rewriteBitFunc(ctx sessionctx.Context, funcType string, arg expression.Expression, targetTp *types.FieldType) expression.Expression {
|
|
// For not integer type. We need to cast(cast(arg as signed) as unsigned) to make the bit function work.
|
|
innerCast := expression.WrapWithCastAsInt(ctx, arg)
|
|
outerCast := a.wrapCastFunction(ctx, innerCast, targetTp)
|
|
var finalExpr expression.Expression
|
|
if funcType != ast.AggFuncBitAnd {
|
|
finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, targetTp, outerCast, expression.Zero.Clone())
|
|
} else {
|
|
finalExpr = expression.NewFunctionInternal(ctx, ast.Ifnull, outerCast.GetType(), outerCast, &expression.Constant{Value: types.NewUintDatum(math.MaxUint64), RetType: targetTp})
|
|
}
|
|
return finalExpr
|
|
}
|
|
|
|
// wrapCastFunction will wrap a cast if the targetTp is not equal to the arg's.
|
|
func (a *aggregationEliminateChecker) wrapCastFunction(ctx sessionctx.Context, arg expression.Expression, targetTp *types.FieldType) expression.Expression {
|
|
if arg.GetType() == targetTp {
|
|
return arg
|
|
}
|
|
return expression.BuildCastFunction(ctx, arg, targetTp)
|
|
}
|
|
|
|
func (a *aggregationEliminator) optimize(p LogicalPlan) (LogicalPlan, error) {
|
|
newChildren := make([]LogicalPlan, 0, len(p.Children()))
|
|
for _, child := range p.Children() {
|
|
newChild, err := a.optimize(child)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
newChildren = append(newChildren, newChild)
|
|
}
|
|
p.SetChildren(newChildren...)
|
|
agg, ok := p.(*LogicalAggregation)
|
|
if !ok {
|
|
return p, nil
|
|
}
|
|
if proj := a.tryToEliminateAggregation(agg); proj != nil {
|
|
return proj, nil
|
|
}
|
|
return p, nil
|
|
}
|