diff --git a/planner/core/exhaust_physical_plans.go b/planner/core/exhaust_physical_plans.go index 60d794bd76..be439dbd94 100644 --- a/planner/core/exhaust_physical_plans.go +++ b/planner/core/exhaust_physical_plans.go @@ -2752,11 +2752,37 @@ func (la *LogicalAggregation) exhaustPhysicalPlans(prop *property.PhysicalProper } func (p *LogicalSelection) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]PhysicalPlan, bool, error) { + newProps := make([]*property.PhysicalProperty, 0, 2) childProp := prop.CloneEssentialFields() - sel := PhysicalSelection{ - Conditions: p.Conditions, - }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, childProp) - return []PhysicalPlan{sel}, true, nil + newProps = append(newProps, childProp) + + if prop.TaskTp != property.MppTaskType && + p.SCtx().GetSessionVars().IsMPPAllowed() && + p.canPushDown(kv.TiFlash) { + childPropMpp := prop.CloneEssentialFields() + childPropMpp.TaskTp = property.MppTaskType + newProps = append(newProps, childPropMpp) + } + + ret := make([]PhysicalPlan, 0, len(newProps)) + for _, newProp := range newProps { + sel := PhysicalSelection{ + Conditions: p.Conditions, + }.Init(p.ctx, p.stats.ScaleByExpectCnt(prop.ExpectedCnt), p.blockOffset, newProp) + ret = append(ret, sel) + } + return ret, true, nil +} + +// utility function to check whether we can push down Selection to TiKV or TiFlash +func (p *LogicalSelection) canPushDown(storeTp kv.StoreType) bool { + return !expression.ContainVirtualColumn(p.Conditions) && + p.canPushToCop(storeTp) && + expression.CanExprsPushDown( + p.SCtx().GetSessionVars().StmtCtx, + p.Conditions, + p.SCtx().GetClient(), + storeTp) } func (p *LogicalLimit) exhaustPhysicalPlans(prop *property.PhysicalProperty) ([]PhysicalPlan, bool, error) { diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 0ba8d6d971..b8bdec71a6 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -4418,6 +4418,48 @@ func TestPushDownProjectionForTiFlash(t *testing.T) { } } +func TestPushDownSelectionForMPP(t *testing.T) { + store, clean := testkit.CreateMockStore(t) + defer clean() + tk := testkit.NewTestKit(t, store) + tk.MustExec("use test") + tk.MustExec("drop table if exists t") + tk.MustExec("create table t (id int, value decimal(6,3), name char(128))") + tk.MustExec("analyze table t") + + // Create virtual tiflash replica info. + dom := domain.GetDomain(tk.Session()) + is := dom.InfoSchema() + db, exists := is.SchemaByName(model.NewCIStr("test")) + require.True(t, exists) + for _, tblInfo := range db.Tables { + if tblInfo.Name.L == "t" { + tblInfo.TiFlashReplica = &model.TiFlashReplicaInfo{ + Count: 1, + Available: true, + } + } + } + + tk.MustExec("set @@tidb_allow_mpp=1; set @@tidb_enforce_mpp=1;") + + var input []string + var output []struct { + SQL string + Plan []string + } + integrationSuiteData := core.GetIntegrationSuiteData() + integrationSuiteData.GetTestCases(t, &input, &output) + for i, tt := range input { + testdata.OnRecord(func() { + output[i].SQL = tt + output[i].Plan = testdata.ConvertRowsToStrings(tk.MustQuery(tt).Rows()) + }) + res := tk.MustQuery(tt) + res.Check(testkit.Rows(output[i].Plan...)) + } +} + func TestPushDownProjectionForMPP(t *testing.T) { store, clean := testkit.CreateMockStore(t) defer clean() diff --git a/planner/core/testdata/enforce_mpp_suite_out.json b/planner/core/testdata/enforce_mpp_suite_out.json index 73440140aa..6f45fe67ab 100644 --- a/planner/core/testdata/enforce_mpp_suite_out.json +++ b/planner/core/testdata/enforce_mpp_suite_out.json @@ -592,30 +592,30 @@ { "SQL": "explain select a from t where t.a>1 or t.a in (select a from t); -- 7. left outer semi join", "Plan": [ - "TableReader_48 8000.00 root data:ExchangeSender_47", - "└─ExchangeSender_47 8000.00 mpp[tiflash] ExchangeType: PassThrough", + "TableReader_51 8000.00 root data:ExchangeSender_50", + "└─ExchangeSender_50 8000.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_8 8000.00 mpp[tiflash] test.t.a", - " └─Selection_45 8000.00 mpp[tiflash] or(gt(test.t.a, 1), Column#3)", + " └─Selection_49 8000.00 mpp[tiflash] or(gt(test.t.a, 1), Column#3)", " └─HashJoin_46 10000.00 mpp[tiflash] left outer semi join, equal:[eq(test.t.a, test.t.a)]", - " ├─ExchangeReceiver_26(Build) 10000.00 mpp[tiflash] ", - " │ └─ExchangeSender_25 10000.00 mpp[tiflash] ExchangeType: Broadcast", - " │ └─TableFullScan_24 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo", - " └─TableFullScan_23(Probe) 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + " ├─ExchangeReceiver_27(Build) 10000.00 mpp[tiflash] ", + " │ └─ExchangeSender_26 10000.00 mpp[tiflash] ExchangeType: Broadcast", + " │ └─TableFullScan_25 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo", + " └─TableFullScan_24(Probe) 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, { "SQL": "explain select a from t where t.a>1 or t.a not in (select a from t); -- now it's supported -- 8. anti left outer semi join", "Plan": [ - "TableReader_48 8000.00 root data:ExchangeSender_47", - "└─ExchangeSender_47 8000.00 mpp[tiflash] ExchangeType: PassThrough", + "TableReader_51 8000.00 root data:ExchangeSender_50", + "└─ExchangeSender_50 8000.00 mpp[tiflash] ExchangeType: PassThrough", " └─Projection_8 8000.00 mpp[tiflash] test.t.a", - " └─Selection_45 8000.00 mpp[tiflash] or(gt(test.t.a, 1), Column#3)", + " └─Selection_49 8000.00 mpp[tiflash] or(gt(test.t.a, 1), Column#3)", " └─HashJoin_46 10000.00 mpp[tiflash] anti left outer semi join, equal:[eq(test.t.a, test.t.a)]", - " ├─ExchangeReceiver_26(Build) 10000.00 mpp[tiflash] ", - " │ └─ExchangeSender_25 10000.00 mpp[tiflash] ExchangeType: Broadcast", - " │ └─TableFullScan_24 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo", - " └─TableFullScan_23(Probe) 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + " ├─ExchangeReceiver_27(Build) 10000.00 mpp[tiflash] ", + " │ └─ExchangeSender_26 10000.00 mpp[tiflash] ExchangeType: Broadcast", + " │ └─TableFullScan_25 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo", + " └─TableFullScan_24(Probe) 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" ], "Warn": null }, diff --git a/planner/core/testdata/integration_suite_in.json b/planner/core/testdata/integration_suite_in.json index 9ae38b419b..ac9347d4ac 100644 --- a/planner/core/testdata/integration_suite_in.json +++ b/planner/core/testdata/integration_suite_in.json @@ -776,6 +776,13 @@ "desc format = 'brief' SELECT FROM_UNIXTIME(name,'%Y-%m-%d') FROM t;" ] }, + { + "name": "TestPushDownSelectionForMPP", + "cases": [ + "desc format = 'brief' select /*+ hash_agg()*/ count(*) c, id from t group by id having id >c", + "desc format = 'brief' select * from t where id < 2" + ] + }, { "name": "TestMppUnionAll", "cases": [ diff --git a/planner/core/testdata/integration_suite_out.json b/planner/core/testdata/integration_suite_out.json index 6f1aee83e4..169e84f602 100644 --- a/planner/core/testdata/integration_suite_out.json +++ b/planner/core/testdata/integration_suite_out.json @@ -2561,10 +2561,10 @@ "Plan": [ "HashJoin_19 3.00 127.40 root CARTESIAN left outer semi join", "├─Selection_39(Build) 0.80 11.18 root eq(2, Column#18)", - "│ └─StreamAgg_60 1.00 8.18 root funcs:count(Column#32)->Column#18", - "│ └─TableReader_61 1.00 5.17 root data:StreamAgg_44", - "│ └─StreamAgg_44 1.00 49.50 batchCop[tiflash] funcs:count(1)->Column#32", - "│ └─TableFullScan_59 3.00 40.50 batchCop[tiflash] table:t1 keep order:false", + "│ └─StreamAgg_61 1.00 8.18 root funcs:count(Column#32)->Column#18", + "│ └─TableReader_62 1.00 5.17 root data:StreamAgg_45", + "│ └─StreamAgg_45 1.00 49.50 batchCop[tiflash] funcs:count(1)->Column#32", + "│ └─TableFullScan_60 3.00 40.50 batchCop[tiflash] table:t1 keep order:false", "└─Projection_20(Probe) 3.00 95.82 root 1->Column#28", " └─Apply_22 3.00 76.02 root CARTESIAN left outer join", " ├─TableReader_24(Build) 3.00 10.16 root data:TableFullScan_23", @@ -5221,6 +5221,33 @@ } ] }, + { + "Name": "TestPushDownSelectionForMPP", + "Cases": [ + { + "SQL": "desc format = 'brief' select /*+ hash_agg()*/ count(*) c, id from t group by id having id >c", + "Plan": [ + "TableReader 6400.00 root data:ExchangeSender", + "└─ExchangeSender 6400.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Selection 6400.00 mpp[tiflash] gt(test.t.id, Column#5)", + " └─Projection 8000.00 mpp[tiflash] Column#5, test.t.id", + " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.id, funcs:count(1)->Column#5, funcs:firstrow(test.t.id)->test.t.id", + " └─ExchangeReceiver 10000.00 mpp[tiflash] ", + " └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.id, collate: binary]", + " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ] + }, + { + "SQL": "desc format = 'brief' select * from t where id < 2", + "Plan": [ + "TableReader 3323.33 root data:ExchangeSender", + "└─ExchangeSender 3323.33 mpp[tiflash] ExchangeType: PassThrough", + " └─Selection 3323.33 mpp[tiflash] lt(test.t.id, 2)", + " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" + ] + } + ] + }, { "Name": "TestMppUnionAll", "Cases": [ @@ -5605,11 +5632,11 @@ "SQL": "desc format = 'brief' select id from t group by id having avg(value)>0", "Plan": [ "Projection 6400.00 root test.t.id", - "└─Selection 6400.00 root gt(Column#4, 0)", - " └─TableReader 8000.00 root data:ExchangeSender", - " └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection 8000.00 mpp[tiflash] div(Column#4, cast(case(eq(Column#9, 0), 1, Column#9), decimal(20,0) BINARY))->Column#4, test.t.id", - " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.id, funcs:count(test.t.value)->Column#9, funcs:sum(test.t.value)->Column#4, funcs:firstrow(test.t.id)->test.t.id", + "└─TableReader 6400.00 root data:ExchangeSender", + " └─ExchangeSender 6400.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Selection 6400.00 mpp[tiflash] gt(Column#4, 0)", + " └─Projection 8000.00 mpp[tiflash] div(Column#4, cast(case(eq(Column#17, 0), 1, Column#17), decimal(20,0) BINARY))->Column#4, test.t.id", + " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.id, funcs:count(test.t.value)->Column#17, funcs:sum(test.t.value)->Column#4, funcs:firstrow(test.t.id)->test.t.id", " └─ExchangeReceiver 10000.00 mpp[tiflash] ", " └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.id, collate: binary]", " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo" @@ -5618,11 +5645,11 @@ { "SQL": "desc format = 'brief' select avg(value),id from t group by id having avg(value)>0", "Plan": [ - "Selection 6400.00 root gt(Column#4, 0)", - "└─TableReader 8000.00 root data:ExchangeSender", - " └─ExchangeSender 8000.00 mpp[tiflash] ExchangeType: PassThrough", - " └─Projection 8000.00 mpp[tiflash] div(Column#4, cast(case(eq(Column#10, 0), 1, Column#10), decimal(20,0) BINARY))->Column#4, test.t.id", - " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.id, funcs:count(test.t.value)->Column#10, funcs:sum(test.t.value)->Column#4, funcs:firstrow(test.t.id)->test.t.id", + "TableReader 6400.00 root data:ExchangeSender", + "└─ExchangeSender 6400.00 mpp[tiflash] ExchangeType: PassThrough", + " └─Selection 6400.00 mpp[tiflash] gt(Column#4, 0)", + " └─Projection 8000.00 mpp[tiflash] div(Column#4, cast(case(eq(Column#18, 0), 1, Column#18), decimal(20,0) BINARY))->Column#4, test.t.id", + " └─HashAgg 8000.00 mpp[tiflash] group by:test.t.id, funcs:count(test.t.value)->Column#18, funcs:sum(test.t.value)->Column#4, funcs:firstrow(test.t.id)->test.t.id", " └─ExchangeReceiver 10000.00 mpp[tiflash] ", " └─ExchangeSender 10000.00 mpp[tiflash] ExchangeType: HashPartition, Hash Cols: [name: test.t.id, collate: binary]", " └─TableFullScan 10000.00 mpp[tiflash] table:t keep order:false, stats:pseudo"