planner: fix aggregation hint didn't work in some cases (#11996)
This commit is contained in:
committed by
pingcap-github-bot
parent
a18ddb811f
commit
0872b65ff1
@ -441,6 +441,9 @@ func (er *expressionRewriter) handleCompareSubquery(ctx context.Context, v *ast.
|
||||
// it will be rewrote to t.id < (select max(s.id) from s).
|
||||
func (er *expressionRewriter) handleOtherComparableSubq(lexpr, rexpr expression.Expression, np LogicalPlan, useMin bool, cmpFunc string, all bool) {
|
||||
plan4Agg := LogicalAggregation{}.Init(er.sctx)
|
||||
if hint := er.b.TableHints(); hint != nil {
|
||||
plan4Agg.preferAggType = hint.preferAggType
|
||||
}
|
||||
plan4Agg.SetChildren(np)
|
||||
|
||||
// Create a "max" or "min" aggregation.
|
||||
@ -567,6 +570,9 @@ func (er *expressionRewriter) handleNEAny(lexpr, rexpr expression.Expression, np
|
||||
plan4Agg := LogicalAggregation{
|
||||
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
|
||||
}.Init(er.sctx)
|
||||
if hint := er.b.TableHints(); hint != nil {
|
||||
plan4Agg.preferAggType = hint.preferAggType
|
||||
}
|
||||
plan4Agg.SetChildren(np)
|
||||
firstRowResultCol := &expression.Column{
|
||||
ColName: model.NewCIStr("col_firstRow"),
|
||||
@ -601,6 +607,9 @@ func (er *expressionRewriter) handleEQAll(lexpr, rexpr expression.Expression, np
|
||||
plan4Agg := LogicalAggregation{
|
||||
AggFuncs: []*aggregation.AggFuncDesc{firstRowFunc, countFunc},
|
||||
}.Init(er.sctx)
|
||||
if hint := er.b.TableHints(); hint != nil {
|
||||
plan4Agg.preferAggType = hint.preferAggType
|
||||
}
|
||||
plan4Agg.SetChildren(np)
|
||||
firstRowResultCol := &expression.Column{
|
||||
ColName: model.NewCIStr("col_firstRow"),
|
||||
|
||||
@ -97,6 +97,9 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
|
||||
b.optFlag = b.optFlag | flagEliminateProjection
|
||||
|
||||
plan4Agg := LogicalAggregation{AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(aggFuncList))}.Init(b.ctx)
|
||||
if hint := b.TableHints(); hint != nil {
|
||||
plan4Agg.preferAggType = hint.preferAggType
|
||||
}
|
||||
schema4Agg := expression.NewSchema(make([]*expression.Column, 0, len(aggFuncList)+p.Schema().Len())...)
|
||||
// aggIdxMap maps the old index to new index after applying common aggregation functions elimination.
|
||||
aggIndexMap := make(map[int]int)
|
||||
@ -149,9 +152,6 @@ func (b *PlanBuilder) buildAggregation(ctx context.Context, p LogicalPlan, aggFu
|
||||
plan4Agg.GroupByItems = gbyItems
|
||||
plan4Agg.SetSchema(schema4Agg)
|
||||
plan4Agg.collectGroupByColumns()
|
||||
if hint := b.TableHints(); hint != nil {
|
||||
plan4Agg.preferAggType = hint.preferAggType
|
||||
}
|
||||
return plan4Agg, aggIndexMap, nil
|
||||
}
|
||||
|
||||
@ -794,6 +794,9 @@ func (b *PlanBuilder) buildDistinct(child LogicalPlan, length int) (*LogicalAggr
|
||||
AggFuncs: make([]*aggregation.AggFuncDesc, 0, child.Schema().Len()),
|
||||
GroupByItems: expression.Column2Exprs(child.Schema().Clone().Columns[:length]),
|
||||
}.Init(b.ctx)
|
||||
if hint := b.TableHints(); hint != nil {
|
||||
plan4Agg.preferAggType = hint.preferAggType
|
||||
}
|
||||
plan4Agg.collectGroupByColumns()
|
||||
for _, col := range child.Schema().Columns {
|
||||
aggDesc, err := aggregation.NewAggFuncDesc(b.ctx, ast.AggFuncFirstRow, []expression.Expression{col}, false)
|
||||
|
||||
@ -1603,31 +1603,28 @@ func (s *testPlanSuite) TestAggregationHints(c *C) {
|
||||
sessionVars.HashAggPartialConcurrency = 1
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
best string
|
||||
warning string
|
||||
sql string
|
||||
best string
|
||||
warning string
|
||||
aggPushDown bool
|
||||
}{
|
||||
// without Aggregation hints
|
||||
{
|
||||
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
|
||||
warning: "",
|
||||
sql: "select count(*) from t t1, t t2 where t1.a = t2.b",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
|
||||
},
|
||||
{
|
||||
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
|
||||
warning: "",
|
||||
sql: "select count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->HashAgg",
|
||||
},
|
||||
// with Aggregation hints
|
||||
{
|
||||
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
|
||||
warning: "",
|
||||
sql: "select /*+ HASH_AGG() */ count(*) from t t1, t t2 where t1.a = t2.b",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->HashAgg",
|
||||
},
|
||||
{
|
||||
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
|
||||
warning: "",
|
||||
sql: "select /*+ STREAM_AGG() */ count(t1.a) from t t1, t t2 where t1.a = t2.a*2 group by t1.a",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection}(test.t1.a,mul(test.t2.a, 2))->Sort->StreamAgg",
|
||||
},
|
||||
// test conflict warning
|
||||
{
|
||||
@ -1635,24 +1632,50 @@ func (s *testPlanSuite) TestAggregationHints(c *C) {
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))}(test.t1.a,test.t2.b)->StreamAgg",
|
||||
warning: "[planner:1815]Optimizer aggregation hints are conflicted",
|
||||
},
|
||||
// additional test
|
||||
{
|
||||
sql: "select /*+ STREAM_AGG() */ distinct a from t",
|
||||
best: "TableReader(Table(t)->StreamAgg)->StreamAgg",
|
||||
},
|
||||
{
|
||||
sql: "select /*+ HASH_AGG() */ t1.a from t t1 where t1.a < any(select t2.b from t t2)",
|
||||
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t)->HashAgg)->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
|
||||
},
|
||||
{
|
||||
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a != any(select t2.b from t t2)",
|
||||
best: "LeftHashJoin{TableReader(Table(t)->Sel([if(isnull(test.t1.a), <nil>, 1)]))->TableReader(Table(t))->Projection->HashAgg->Sel([ne(agg_col_cnt, 0)])}->Projection->Projection",
|
||||
},
|
||||
{
|
||||
sql: "select /*+ hash_agg() */ t1.a from t t1 where t1.a = all(select t2.b from t t2)",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Projection->HashAgg}->Projection->Projection",
|
||||
},
|
||||
{
|
||||
sql: "select /*+ STREAM_AGG() */ sum(t1.a) from t t1 join t t2 on t1.b = t2.b group by t1.b",
|
||||
best: "LeftHashJoin{TableReader(Table(t))->TableReader(Table(t))->Sort->Projection->StreamAgg}(test.t2.b,test.t1.b)->HashAgg",
|
||||
warning: "[planner:1815]Optimizer Hint STREAM_AGG is inapplicable",
|
||||
aggPushDown: true,
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
for i, test := range tests {
|
||||
comment := Commentf("case:%v sql:%s", i, test)
|
||||
se.GetSessionVars().StmtCtx.SetWarnings(nil)
|
||||
se.GetSessionVars().AllowAggPushDown = test.aggPushDown
|
||||
|
||||
stmt, err := s.ParseOneStmt(test.sql, "", "")
|
||||
c.Assert(err, IsNil, comment)
|
||||
|
||||
p, err := planner.Optimize(ctx, se, stmt, s.is)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(core.ToString(p), Equals, test.best)
|
||||
c.Assert(core.ToString(p), Equals, test.best, comment)
|
||||
|
||||
warnings := se.GetSessionVars().StmtCtx.GetWarnings()
|
||||
if test.warning == "" {
|
||||
c.Assert(len(warnings), Equals, 0)
|
||||
c.Assert(len(warnings), Equals, 0, comment)
|
||||
} else {
|
||||
c.Assert(len(warnings), Equals, 1)
|
||||
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning)
|
||||
c.Assert(warnings[0].Err.Error(), Equals, test.warning)
|
||||
c.Assert(len(warnings), Equals, 1, comment)
|
||||
c.Assert(warnings[0].Level, Equals, stmtctx.WarnLevelWarning, comment)
|
||||
c.Assert(warnings[0].Err.Error(), Equals, test.warning, comment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -189,7 +189,7 @@ func (a *aggregationPushDownSolver) decompose(ctx sessionctx.Context, aggFunc *a
|
||||
// tryToPushDownAgg tries to push down an aggregate function into a join path. If all aggFuncs are first row, we won't
|
||||
// process it temporarily. If not, We will add additional group by columns and first row functions. We make a new aggregation operator.
|
||||
// If the pushed aggregation is grouped by unique key, it's no need to push it down.
|
||||
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int) (_ LogicalPlan, err error) {
|
||||
func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, join *LogicalJoin, childIdx int, preferAggType uint) (_ LogicalPlan, err error) {
|
||||
child := join.children[childIdx]
|
||||
if aggregation.IsAllFirstRow(aggFuncs) {
|
||||
return child, nil
|
||||
@ -204,7 +204,7 @@ func (a *aggregationPushDownSolver) tryToPushDownAgg(aggFuncs []*aggregation.Agg
|
||||
return child, nil
|
||||
}
|
||||
}
|
||||
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols)
|
||||
agg, err := a.makeNewAgg(join.ctx, aggFuncs, gbyCols, preferAggType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -247,10 +247,11 @@ func (a *aggregationPushDownSolver) checkAnyCountAndSum(aggFuncs []*aggregation.
|
||||
return false
|
||||
}
|
||||
|
||||
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column) (*LogicalAggregation, error) {
|
||||
func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs []*aggregation.AggFuncDesc, gbyCols []*expression.Column, preferAggType uint) (*LogicalAggregation, error) {
|
||||
agg := LogicalAggregation{
|
||||
GroupByItems: expression.Column2Exprs(gbyCols),
|
||||
groupByCols: gbyCols,
|
||||
GroupByItems: expression.Column2Exprs(gbyCols),
|
||||
groupByCols: gbyCols,
|
||||
preferAggType: preferAggType,
|
||||
}.Init(ctx)
|
||||
aggLen := len(aggFuncs) + len(gbyCols)
|
||||
newAggFuncDescs := make([]*aggregation.AggFuncDesc, 0, aggLen)
|
||||
@ -282,8 +283,9 @@ func (a *aggregationPushDownSolver) makeNewAgg(ctx sessionctx.Context, aggFuncs
|
||||
func (a *aggregationPushDownSolver) pushAggCrossUnion(agg *LogicalAggregation, unionSchema *expression.Schema, unionChild LogicalPlan) LogicalPlan {
|
||||
ctx := agg.ctx
|
||||
newAgg := LogicalAggregation{
|
||||
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
|
||||
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
|
||||
AggFuncs: make([]*aggregation.AggFuncDesc, 0, len(agg.AggFuncs)),
|
||||
GroupByItems: make([]expression.Expression, 0, len(agg.GroupByItems)),
|
||||
preferAggType: agg.preferAggType,
|
||||
}.Init(ctx)
|
||||
newAgg.SetSchema(agg.schema.Clone())
|
||||
for _, aggFunc := range agg.AggFuncs {
|
||||
@ -340,7 +342,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
|
||||
if rightInvalid {
|
||||
rChild = join.children[1]
|
||||
} else {
|
||||
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1)
|
||||
rChild, err = a.tryToPushDownAgg(rightAggFuncs, rightGbyCols, join, 1, agg.preferAggType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -348,7 +350,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
|
||||
if leftInvalid {
|
||||
lChild = join.children[0]
|
||||
} else {
|
||||
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0)
|
||||
lChild, err = a.tryToPushDownAgg(leftAggFuncs, leftGbyCols, join, 0, agg.preferAggType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -380,7 +382,7 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e
|
||||
} else if union, ok1 := child.(*LogicalUnionAll); ok1 {
|
||||
var gbyCols []*expression.Column
|
||||
gbyCols = expression.ExtractColumnsFromExpressions(gbyCols, agg.GroupByItems, nil)
|
||||
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols)
|
||||
pushedAgg, err := a.makeNewAgg(agg.ctx, agg.AggFuncs, gbyCols, agg.preferAggType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user