diff --git a/cmd/explaintest/r/tpch.result b/cmd/explaintest/r/tpch.result index f7b02fc2f8..9eba9db193 100644 --- a/cmd/explaintest/r/tpch.result +++ b/cmd/explaintest/r/tpch.result @@ -1222,7 +1222,7 @@ id count task operator info Projection_25 1.00 root Column#2, Column#104 └─TopN_28 1.00 root Column#104:desc, Column#2:asc, offset:0, count:100 └─HashAgg_34 1.00 root group by:Column#2, funcs:count(1), firstrow(Column#2) - └─IndexHashJoin_49 7828961.66 root anti semi join, inner:IndexLookUp_39, outer key:Column#8, inner key:Column#71, other cond:ne(Column#73, Column#1), ne(Column#73, Column#10) + └─IndexHashJoin_49 7828961.66 root anti semi join, inner:IndexLookUp_39, outer key:Column#8, inner key:Column#71, other cond:ne(Column#73, Column#10) ├─IndexHashJoin_84 9786202.08 root semi join, inner:IndexLookUp_75, outer key:Column#8, inner key:Column#38, other cond:ne(Column#40, Column#1), ne(Column#40, Column#10) │ ├─IndexMergeJoin_95 12232752.60 root inner join, inner:TableReader_93, outer key:Column#8, inner key:Column#25 │ │ ├─HashRightJoin_101 12232752.60 root inner join, inner:HashRightJoin_114, equal:[eq(Column#1, Column#10)] diff --git a/planner/core/integration_test.go b/planner/core/integration_test.go index 7eeeacc674..2e50806a86 100644 --- a/planner/core/integration_test.go +++ b/planner/core/integration_test.go @@ -171,3 +171,48 @@ func (s *testIntegrationSuite) TestApplyNotNullFlag(c *C) { tk.MustQuery("select IFNULL((select t1.x from t1 where t1.x = t2.x), 'xxx') as col1 from t2").Check(testkit.Rows("xxx")) } + +func (s *testIntegrationSuite) TestAntiJoinConstProp(c *C) { + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + tk.MustExec("use test") + tk.MustExec("drop table if exists t1, t2") + tk.MustExec("create table t1(a int not null, b int not null)") + tk.MustExec("insert into t1 values (1,1)") + tk.MustExec("create table t2(a int not null, b 
int not null)") + tk.MustExec("insert into t2 values (2,2)") + + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.a = t1.a and t2.a > 1)").Check(testkit.Rows( + "1 1", + )) + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.b = t1.b and t2.a > 1)").Check(testkit.Rows( + "1 1", + )) + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.b = t1.b and t2.b > 1)").Check(testkit.Rows( + "1 1", + )) + tk.MustQuery("select q.a in (select count(*) from t1 s where not exists (select 1 from t1 p where q.a > 1 and p.a = s.a)) from t1 q").Check(testkit.Rows( + "1", + )) + tk.MustQuery("select q.a in (select not exists (select 1 from t1 p where q.a > 1 and p.a = s.a) from t1 s) from t1 q").Check(testkit.Rows( + "1", + )) + + tk.MustExec("drop table t1, t2") + tk.MustExec("create table t1(a int not null, b int)") + tk.MustExec("insert into t1 values (1,null)") + tk.MustExec("create table t2(a int not null, b int)") + tk.MustExec("insert into t2 values (2,2)") + + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t2.b > t1.b)").Check(testkit.Rows( + "1 <nil>", + )) + tk.MustQuery("select * from t1 where t1.a not in (select a from t2 where t1.a = 2)").Check(testkit.Rows( + "1 <nil>", + )) +} diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 8d7bca7cf2..fc3472a41f 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -234,9 +234,14 @@ func (p *LogicalJoin) pushDownConstExpr(expr expression.Expression, leftCond []e } else { leftCond = append(leftCond, expr) } - case SemiJoin, AntiSemiJoin, InnerJoin: + case SemiJoin, InnerJoin: leftCond = append(leftCond, expr) rightCond = append(rightCond, expr) + case AntiSemiJoin: + if filterCond { + leftCond = append(leftCond, expr) + } + rightCond = append(rightCond, expr) } return leftCond, rightCond } @@ -263,18 +268,13 @@ func (p *LogicalJoin)
extractOnCondition(conditions []expression.Expression, der arg0, arg1 = arg1, arg0 } if leftCol != nil && rightCol != nil { - // Do not derive `is not null` for anti join, since it may cause wrong results. - // For example: - // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, - // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b` does not imply `t1.b is not null`, - // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`, - if deriveLeft && p.JoinType != AntiSemiJoin { + if deriveLeft { if isNullRejected(ctx, left.Schema(), expr) && !mysql.HasNotNullFlag(leftCol.RetType.Flag) { notNullExpr := expression.BuildNotNullExpr(ctx, leftCol) leftCond = append(leftCond, notNullExpr) } } - if deriveRight && p.JoinType != AntiSemiJoin { + if deriveRight { if isNullRejected(ctx, right.Schema(), expr) && !mysql.HasNotNullFlag(rightCol.RetType.Flag) { notNullExpr := expression.BuildNotNullExpr(ctx, rightCol) rightCond = append(rightCond, notNullExpr) diff --git a/planner/core/logical_plan_test.go b/planner/core/logical_plan_test.go index fd281d1d0a..cca9899c24 100644 --- a/planner/core/logical_plan_test.go +++ b/planner/core/logical_plan_test.go @@ -457,7 +457,7 @@ func (s *testPlanSuite) TestAntiSemiJoinConstFalse(c *C) { }{ { sql: "select a from t t1 where not exists (select a from t t2 where t1.a = t2.a and t2.b = 1 and t2.b = 2)", - best: "Join{DataScan(t1)->DataScan(t2)}->Projection", + best: "Join{DataScan(t1)->DataScan(t2)}(Column#1,Column#13)->Projection", joinType: "anti semi join", }, } diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index f01dbdef6c..31317eb66b 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -153,7 +153,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret p.LeftConditions = nil ret = 
append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) ret = append(ret, leftPushCond...) - case SemiJoin, AntiSemiJoin, InnerJoin: + case SemiJoin, InnerJoin: tempCond := make([]expression.Expression, 0, len(p.LeftConditions)+len(p.RightConditions)+len(p.EqualConditions)+len(p.OtherConditions)+len(predicates)) tempCond = append(tempCond, p.LeftConditions...) tempCond = append(tempCond, p.RightConditions...) @@ -162,13 +162,10 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret tempCond = append(tempCond, predicates...) tempCond = expression.ExtractFiltersFromDNFs(p.ctx, tempCond) tempCond = expression.PropagateConstant(p.ctx, tempCond) - // Return table dual when filter is constant false or null. Not applicable to AntiSemiJoin. - // TODO: For AntiSemiJoin, we can use outer plan to substitute LogicalJoin actually. - if p.JoinType != AntiSemiJoin { - dual := conds2TableDual(p, tempCond) - if dual != nil { - return ret, dual - } + // Return table dual when filter is constant false or null. + dual := conds2TableDual(p, tempCond) + if dual != nil { + return ret, dual } equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(tempCond, true, true) p.LeftConditions = nil @@ -177,6 +174,24 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret p.OtherConditions = otherCond leftCond = leftPushCond rightCond = rightPushCond + case AntiSemiJoin: + predicates = expression.PropagateConstant(p.ctx, predicates) + // Return table dual when filter is constant false or null. + dual := conds2TableDual(p, predicates) + if dual != nil { + return ret, dual + } + // `predicates` should only contain left conditions or constant filters. + _, leftPushCond, rightPushCond, _ = p.extractOnCondition(predicates, true, true) + // Do not derive `is not null` for anti join, since it may cause wrong results. 
+ // For example: + // `select * from t t1 where t1.a not in (select b from t t2)` does not imply `t2.b is not null`, + // `select * from t t1 where t1.a not in (select a from t t2 where t1.b = t2.b)` does not imply `t1.b is not null`, + // `select * from t t1 where not exists (select * from t t2 where t2.a = t1.a)` does not imply `t1.a is not null`. + leftCond = leftPushCond + rightCond = append(p.RightConditions, rightPushCond...) + p.RightConditions = nil + } leftCond = expression.RemoveDupExprs(p.ctx, leftCond) rightCond = expression.RemoveDupExprs(p.ctx, rightCond)