planner: fix aggregation hint didn't work in some cases (#11996)

This commit is contained in:
Zijie Zhang
2019-09-04 14:08:35 +08:00
committed by pingcap-github-bot
parent a18ddb811f
commit 0872b65ff1
4 changed files with 70 additions and 33 deletions

View File

@ -441,6 +441,9 @@ func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast.
// it will be rewritten to t.id < (select max(s.id) from s).
func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.Expression, np LogicalPlan, useMin bool, cmpFunc string, all bool) {
plan4Agg := LogicalAggregation{}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)
// Create a "max" or "min" aggregation.
@ -567,6 +570,9 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),
@ -601,6 +607,9 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
plan4Agg := LogicalAggregation{
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
}.Init(er.sctx)
if hint := er.b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.SetChildren(np)
firstRowResultCol := &expression.Column{
ColName: model.NewCIStr("col_firstRow"),

View File

@ -97,6 +97,9 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
b.optFlag = b.optFlag | flagEliminateProjection
plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx)
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...)
// aggIdxMap maps the old index to new index after applying common aggregation functions elimination.
aggIndexMap := make(map[int]int)
@ -149,9 +152,6 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
plan4Agg.GroupByItems = gbyItems
plan4Agg.SetSchema(schema4Agg)
plan4Agg.collectGroupByColumns()
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
return plan4Agg, aggIndexMap, nil
}
@ -794,6 +794,9 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggr
AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()),
GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]),
}.Init(b.ctx)
if hint := b.TableHints(); hint != nil {
plan4Agg.preferAggType = hint.preferAggType
}
plan4Agg.collectGroupByColumns()
for _, col := range child.Schema().Columns {
aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false)

View File

@ -1603,31 +1603,28 @@ func (s *testPlanSuite) TestAggregationHints(c *C) {
sessionVars.HashAggPartialConcurrency = 1
tests := []struct {
sql string
best string
warning string
sql string
best string
warning string
aggPushDown bool
}{
// without Aggregation hints
{
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
warning: "",
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
},
{
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
warning: "",
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
},
// with Aggregation hints
{
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
warning: "",
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
},
{
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
warning: "",
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
},
// test conflict warning
{
@ -1635,24 +1632,50 @@ func (s *testPlanSuite) TestAggregationHints(c *C) {
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
warning: "[planner:1815]Optimizer aggregation hints are conflicted",
},
// additional test
{
sql: "select /*+ STREAM_AGG() */ distinct a from t",
best: "TableReader(Table(t)->StreamAgg)->StreamAgg",
},
{
sql: "select /*+ HASH_AGG() */ t1.a from t t1 where t1.a < any(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t)->HashAgg)->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
},
{
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a != any(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t))->Projection->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
},
{
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a = all(select t2.b from t t2)",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection->HashAgg}->Projection->Projection",
},
{
sql: "select /*+ STREAM_AGG() */ sum(t1.a) from t t1 join t t2 on t1.b = t2.b group by t1.b",
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Sort->Projection->StreamAgg}(test.t2.b,test.t1.b)->HashAgg",
warning: "[planner:1815]Optimizer Hint STREAM_AGG is inapplicable",
aggPushDown: true,
},
}
ctx := context.Background()
for i, test := range tests {
comment := Commentf("case:%v sql:%s", i, test)
se.GetSessionVars().StmtCtx.SetWarnings(nil)
se.GetSessionVars().AllowAggPushDown = test.aggPushDown
stmt, err := s.ParseOneStmt(test.sql, "", "")
c.Assert(err, IsNil, comment)
p, err := planner.Optimize(ctx, se, stmt, s.is)
c.Assert(err, IsNil)
c.Assert(core.ToString(p), Equals, test.best)
c.Assert(core.ToString(p), Equals, test.best, comment)
warnings := se.GetSessionVars().StmtCtx.GetWarnings()
if test.warning == "" {
c.Assert(len(warnings), Equals, 0)
c.Assert(len(warnings), Equals, 0, comment)
} else {
c.Assert(len(warnings), Equals, 1)
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning)
c.Assert(warnings[0].Err.Error(), Equals, test.warning)
c.Assert(len(warnings), Equals, 1, comment)
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning, comment)
c.Assert(warnings[0].Err.Error(), Equals, test.warning, comment)
}
}
}

View File

@ -189,7 +189,7 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a
// tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't
// process it for now. If not, we will add additional group by columns and first row functions, and make a new aggregation operator.
// If the pushed aggregation is grouped by a unique key, there is no need to push it down.
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (_ LogicalPlan, err error) {
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int, preferAggType uint) (_ LogicalPlan, err error) {
child := join.children[childIdx]
if aggregation.IsAllFirstRow(aggFuncs) {
return child, nil
@ -204,7 +204,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg
return child, nil
}
}
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols)
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, preferAggType)
if err != nil {
return nil, err
}
@ -247,10 +247,11 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation.
return false
}
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) {
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, preferAggType uint) (*LogicalAggregation, error) {
agg := LogicalAggregation{
GroupByItems: expression.Column2Exprs(gbyCols),
groupByCols: gbyCols,
GroupByItems: expression.Column2Exprs(gbyCols),
groupByCols: gbyCols,
preferAggType: preferAggType,
}.Init(ctx)
aggLen := len(aggFuncs) + len(gbyCols)
newAggFuncDescs := make([]*aggregation.AggFuncDesc, 0, aggLen)
@ -282,8 +283,9 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs
func (a *aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan {
ctx := agg.ctx
newAgg := LogicalAggregation{
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
preferAggType: agg.preferAggType,
}.Init(ctx)
newAgg.SetSchema(agg.schema.Clone())
for _, aggFunc := range agg.AggFuncs {
@ -340,7 +342,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
if rightInvalid {
rChild = join.children[1]
} else {
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1)
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1, agg.preferAggType)
if err != nil {
return nil, err
}
@ -348,7 +350,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
if leftInvalid {
lChild = join.children[0]
} else {
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0)
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0, agg.preferAggType)
if err != nil {
return nil, err
}
@ -380,7 +382,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
} else if union, ok1 := child.(*LogicalUnionAll); ok1 {
var gbyCols []*expression.Column
gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil)
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols)
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols, agg.preferAggType)
if err != nil {
return nil, err
}