From 8dcebd12395d3d0dda030892fe751dfa5fd56de3 Mon Sep 17 00:00:00 2001 From: Zhou Kunqin <25057648+time-and-fate@users.noreply.github.com> Date: Fri, 27 Aug 2021 16:46:05 +0800 Subject: [PATCH] planner, expression: avoid exprs with side effects in column pruning and agg pushdown (#27370) --- planner/core/integration_test.go | 28 +++++++ planner/core/rule_aggregation_push_down.go | 47 ++++++++---- planner/core/rule_column_pruning.go | 4 +- .../core/testdata/integration_suite_in.json | 13 ++++ .../core/testdata/integration_suite_out.json | 76 +++++++++++++++++++ 5 files changed, 152 insertions(+), 16 deletions(-) diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index e8e6d08628..a6c12579f0 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -4307,3 +4307,31 @@ func (s *testIntegrationSerialSuite) TestTemporaryTableForCte(c *C) { rows = tk.MustQuery("WITH RECURSIVE cte(a) AS (SELECT 1 UNION SELECT a+1 FROM tmp1 WHERE a < 5) SELECT * FROM cte order by a;") rows.Check(testkit.Rows("1", "2", "3", "4", "5")) } + +func (s *testIntegrationSuite) TestGroupBySetVar(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t1") + tk.MustExec("create table t1(c1 int);") + tk.MustExec("insert into t1 values(1), (2), (3), (4), (5), (6);") + rows := tk.MustQuery("select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;") + rows.Check(testkit.Rows("0 2", "1 2", "2 2")) + + tk.MustExec("create table ta(a int, b int);") + tk.MustExec("set sql_mode='';") + + var input []string + var output []struct { + SQL string + Plan []string + } + s.testData.GetTestCases(c, &input, &output) + for i, tt := range input { + res := tk.MustQuery("explain format = 'brief' " + tt) + s.testData.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = s.testData.ConvertRowsToStrings(res.Rows()) + }) + res.Check(testkit.Rows(output[i].Plan...)) + } +} diff --git a/planner/core/rule_aggregation_push_down.go b/planner/core/rule_aggregation_push_down.go index 4ad2e6f260..0dc84b07be 100644 --- a/planner/core/rule_aggregation_push_down.go +++ b/planner/core/rule_aggregation_push_down.go @@ -428,22 +428,41 @@ func (a *aggregationPushDownSolver) aggPushDown(p LogicalPlan) (_ LogicalPlan, e } else if proj, ok1 := child.(*LogicalProjection); ok1 { // TODO: This optimization is not always reasonable. We have not supported pushing projection to kv layer yet, // so we must do this optimization. - for i, gbyItem := range agg.GroupByItems { - agg.GroupByItems[i] = expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs) - } - for _, aggFunc := range agg.AggFuncs { - newArgs := make([]expression.Expression, 0, len(aggFunc.Args)) - for _, arg := range aggFunc.Args { - newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs)) + noSideEffects := true + newGbyItems := make([]expression.Expression, 0, len(agg.GroupByItems)) + for _, gbyItem := range agg.GroupByItems { + newGbyItems = append(newGbyItems, expression.ColumnSubstitute(gbyItem, proj.schema, proj.Exprs)) + if ExprsHasSideEffects(newGbyItems) { + noSideEffects = false + break } - aggFunc.Args = newArgs } - projChild := proj.children[0] - agg.SetChildren(projChild) - // When the origin plan tree is `Aggregation->Projection->Union All->X`, we need to merge 'Aggregation' and 'Projection' first. - // And then push the new 'Aggregation' below the 'Union All' . - // The final plan tree should be 'Aggregation->Union All->Aggregation->X'. - child = projChild + newAggFuncsArgs := make([][]expression.Expression, 0, len(agg.AggFuncs)) + if noSideEffects { + for _, aggFunc := range agg.AggFuncs { + newArgs := make([]expression.Expression, 0, len(aggFunc.Args)) + for _, arg := range aggFunc.Args { + newArgs = append(newArgs, expression.ColumnSubstitute(arg, proj.schema, proj.Exprs)) + } + if ExprsHasSideEffects(newArgs) { + noSideEffects = false + break + } + newAggFuncsArgs = append(newAggFuncsArgs, newArgs) + } + } + if noSideEffects { + agg.GroupByItems = newGbyItems + for i, aggFunc := range agg.AggFuncs { + aggFunc.Args = newAggFuncsArgs[i] + } + projChild := proj.children[0] + agg.SetChildren(projChild) + // When the origin plan tree is `Aggregation->Projection->Union All->X`, we need to merge 'Aggregation' and 'Projection' first. + // And then push the new 'Aggregation' below the 'Union All' . + // The final plan tree should be 'Aggregation->Union All->Aggregation->X'. + child = projChild + } } if union, ok1 := child.(*LogicalUnionAll); ok1 && p.SCtx().GetSessionVars().AllowAggPushDown { err := a.tryAggPushDownForUnion(union, agg) diff --git a/planner/core/rule_column_pruning.go b/planner/core/rule_column_pruning.go index 77154520bc..46ecf884d5 100644 --- a/planner/core/rule_column_pruning.go +++ b/planner/core/rule_column_pruning.go @@ -96,7 +96,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) if la.AggFuncs[i].Name != ast.AggFuncFirstRow { allFirstRow = false } - if !used[i] { + if !used[i] && !ExprsHasSideEffects(la.AggFuncs[i].Args) { la.schema.Columns = append(la.schema.Columns[:i], la.schema.Columns[i+1:]...) la.AggFuncs = append(la.AggFuncs[:i], la.AggFuncs[i+1:]...) } else if la.AggFuncs[i].Name != ast.AggFuncFirstRow { @@ -137,7 +137,7 @@ func (la *LogicalAggregation) PruneColumns(parentUsedCols []*expression.Column) if len(la.GroupByItems) > 0 { for i := len(la.GroupByItems) - 1; i >= 0; i-- { cols := expression.ExtractColumns(la.GroupByItems[i]) - if len(cols) == 0 { + if len(cols) == 0 && !exprHasSetVarOrSleep(la.GroupByItems[i]) { la.GroupByItems = append(la.GroupByItems[:i], la.GroupByItems[i+1:]...) } else { selfUsedCols = append(selfUsedCols, cols...) diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index 10dfddf8c8..a8b601d99f 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -355,5 +355,18 @@ "cases": [ "select * from t use index (idx_b) where b = 2 limit 1" ] + }, + { + "name": "TestGroupBySetVar", + "cases": [ + "select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;", + // TODO: fix these two cases + "select @n:=@n+1 as e from ta group by e", + "select @n:=@n+a as e from ta group by e", + "select * from (select @n:=@n+1 as e from ta) tt group by e", + "select * from (select @n:=@n+a as e from ta) tt group by e", + "select a from ta group by @n:=@n+1", + "select a from ta group by @n:=@n+a" + ] } ] diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 261b35bc71..9d8c313a49 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -1897,5 +1897,81 @@ ] } ] + }, + { + "Name": "TestGroupBySetVar", + "Cases": [ + { + "SQL": "select floor(dt.rn/2) rownum, count(c1) from (select @rownum := @rownum + 1 rn, c1 from (select @rownum := -1) drn, t1) dt group by floor(dt.rn/2) order by rownum;", + "Plan": [ + "Sort 1.00 root Column#6", + "└─Projection 1.00 root floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#6, Column#5", + " └─HashAgg 1.00 root group by:Column#13, funcs:count(Column#11)->Column#5, funcs:firstrow(Column#12)->Column#4", + " └─Projection 10000.00 root test.t1.c1, Column#4, floor(div(cast(Column#4, decimal(20,0) BINARY), 2))->Column#13", + " └─Projection 10000.00 root setvar(rownum, plus(getvar(rownum), 1))->Column#4, test.t1.c1", + " └─HashJoin 10000.00 root CARTESIAN inner join", + " ├─Projection(Build) 1.00 root setvar(rownum, -1)->Column#1", + " │ └─TableDual 1.00 root rows:1", + " └─TableReader(Probe) 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:t1 keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select @n:=@n+1 as e from ta group by e", + "Plan": [ + "Projection 1.00 root setvar(n, plus(getvar(n), 1))->Column#4", + "└─HashAgg 1.00 root group by:Column#8, funcs:firstrow(1)->Column#7", + " └─Projection 10000.00 root setvar(n, plus(cast(getvar(n), double BINARY), 1))->Column#8", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select @n:=@n+a as e from ta group by e", + "Plan": [ + "Projection 8000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4", + "└─HashAgg 8000.00 root group by:Column#7, funcs:firstrow(Column#6)->test.ta.a", + " └─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#7", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select * from (select @n:=@n+1 as e from ta) tt group by e", + "Plan": [ + "HashAgg 1.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4", + "└─Projection 10000.00 root setvar(n, plus(getvar(n), 1))->Column#4", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select * from (select @n:=@n+a as e from ta) tt group by e", + "Plan": [ + "HashAgg 8000.00 root group by:Column#4, funcs:firstrow(Column#4)->Column#4", + "└─Projection 10000.00 root setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#4", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select a from ta group by @n:=@n+1", + "Plan": [ + "HashAgg 1.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a", + "└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), 1))->Column#5", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + }, + { + "SQL": "select a from ta group by @n:=@n+a", + "Plan": [ + "HashAgg 8000.00 root group by:Column#5, funcs:firstrow(Column#4)->test.ta.a", + "└─Projection 10000.00 root test.ta.a, setvar(n, plus(getvar(n), cast(test.ta.a, double BINARY)))->Column#5", + " └─TableReader 10000.00 root data:TableFullScan", + " └─TableFullScan 10000.00 cop[tikv] table:ta keep order:false, stats:pseudo" + ] + } + ] } ]