From 071caf940e19bbdf1f1482bfe096ef6ae2e5a7a6 Mon Sep 17 00:00:00 2001 From: Han Fei Date: Fri, 2 Dec 2016 22:56:25 +0800 Subject: [PATCH] plan: support pushing agg across projection. (#2156) --- plan/aggregation_push_down.go | 17 +++++++++++++++++ plan/logical_plan_test.go | 6 +++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/plan/aggregation_push_down.go b/plan/aggregation_push_down.go index 86d04db716..98735947e6 100644 --- a/plan/aggregation_push_down.go +++ b/plan/aggregation_push_down.go @@ -327,6 +327,23 @@ func (a *aggPushDownSolver) aggPushDown(p LogicalPlan) { rChild.SetParents(join) join.SetSchema(append(lChild.GetSchema().Clone(), rChild.GetSchema().Clone()...)) } + } else if proj, ok1 := child.(*Projection); ok1 { + // TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet, + // so we must do this optimization. + for i, gbyItem := range agg.GroupByItems { + agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs) + } + agg.collectGroupByColumns() + for _, aggFunc := range agg.AggFuncs { + newArgs := make([]expression.Expression, 0, len(aggFunc.GetArgs())) + for _, arg := range aggFunc.GetArgs() { + newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs)) + } + aggFunc.SetArgs(newArgs) + } + projChild := proj.children[0] + agg.SetChildren(projChild) + projChild.SetParents(agg) } else if union, ok1 := child.(*Union); ok1 { pushedAgg := a.makeNewAgg(agg.AggFuncs, agg.groupByCols) newChildren := make([]Plan, 0, len(union.children)) diff --git a/plan/logical_plan_test.go b/plan/logical_plan_test.go index fa7c7573ba..87ee8ddca6 100644 --- a/plan/logical_plan_test.go +++ b/plan/logical_plan_test.go @@ -675,9 +675,13 @@ func (s *testPlanSuite) TestAggPushDown(c *C) { sql: "select sum(a.a) from t a right join t b on a.c = b.c", best: "Join{DataScan(a)->Aggr(sum(a.a),firstrow(a.c))->DataScan(b)}(a.c,b.c)->Aggr(sum(join_agg_0))->Projection", }, + { + sql: "select sum(a) from (select * from t) x", + best: "DataScan(t)->Aggr(sum(test.t.a))->Projection", + }, { sql: "select sum(c1) from (select c c1, d c2 from t a union all select a c1, b c2 from t b union all select b c1, e c2 from t c) x group by c2", - best: "UnionAll{DataScan(a)->Projection->Aggr(sum(c1),firstrow(c2))->DataScan(b)->Projection->Aggr(sum(c1),firstrow(c2))->DataScan(c)->Projection->Aggr(sum(c1),firstrow(c2))}->Aggr(sum(join_agg_0))->Projection", + best: "UnionAll{DataScan(a)->Aggr(sum(a.c),firstrow(a.d))->DataScan(b)->Aggr(sum(b.a),firstrow(b.b))->DataScan(c)->Aggr(sum(c.b),firstrow(c.e))}->Aggr(sum(join_agg_0))->Projection", }, } for _, ca := range cases {