diff --git a/planner/cascades/implementation_rules.go b/planner/cascades/implementation_rules.go index aa1cc94af2..082198b78f 100644 --- a/planner/cascades/implementation_rules.go +++ b/planner/cascades/implementation_rules.go @@ -409,12 +409,11 @@ type ImplHashJoinBuildLeft struct { // Match implements ImplementationRule Match interface. func (r *ImplHashJoinBuildLeft) Match(expr *memo.GroupExpr, prop *property.PhysicalProperty) (matched bool) { - join := expr.ExprNode.(*plannercore.LogicalJoin) - switch join.JoinType { - case plannercore.SemiJoin, plannercore.AntiSemiJoin, plannercore.LeftOuterSemiJoin, plannercore.AntiLeftOuterSemiJoin: - return false - default: + switch expr.ExprNode.(*plannercore.LogicalJoin).JoinType { + case plannercore.InnerJoin, plannercore.LeftOuterJoin, plannercore.RightOuterJoin: return prop.IsEmpty() + default: + return false } } @@ -424,8 +423,11 @@ func (r *ImplHashJoinBuildLeft) OnImplement(expr *memo.GroupExpr, reqProp *prope switch join.JoinType { case plannercore.InnerJoin: return getImplForHashJoin(expr, reqProp, 0, false), nil + case plannercore.LeftOuterJoin: + return getImplForHashJoin(expr, reqProp, 1, true), nil + case plannercore.RightOuterJoin: + return getImplForHashJoin(expr, reqProp, 0, false), nil default: - // TODO: deal with other join type. return nil, nil } } @@ -443,12 +445,17 @@ func (r *ImplHashJoinBuildRight) Match(expr *memo.GroupExpr, prop *property.Phys func (r *ImplHashJoinBuildRight) OnImplement(expr *memo.GroupExpr, reqProp *property.PhysicalProperty) (memo.Implementation, error) { join := expr.ExprNode.(*plannercore.LogicalJoin) switch join.JoinType { + case plannercore.SemiJoin, plannercore.AntiSemiJoin, + plannercore.LeftOuterSemiJoin, plannercore.AntiLeftOuterSemiJoin: + return getImplForHashJoin(expr, reqProp, 1, false), nil case plannercore.InnerJoin: return getImplForHashJoin(expr, reqProp, 1, false), nil - default: - // TODO: deal with other join type. 
- return nil, nil + case plannercore.LeftOuterJoin: + return getImplForHashJoin(expr, reqProp, 1, false), nil + case plannercore.RightOuterJoin: + return getImplForHashJoin(expr, reqProp, 0, true), nil } + return nil, nil } // ImplUnionAll implements LogicalUnionAll to PhysicalUnionAll. diff --git a/planner/cascades/integration_test.go b/planner/cascades/integration_test.go index 3561179606..197a92977f 100644 --- a/planner/cascades/integration_test.go +++ b/planner/cascades/integration_test.go @@ -202,7 +202,7 @@ func (s *testIntegrationSuite) TestJoin(c *C) { tk.MustExec("create table t1(a int primary key, b int)") tk.MustExec("create table t2(a int primary key, b int)") tk.MustExec("insert into t1 values (1, 11), (4, 44), (2, 22), (3, 33)") - tk.MustExec("insert into t2 values (1, 111), (2, 222), (3, 333)") + tk.MustExec("insert into t2 values (1, 111), (2, 222), (3, 333), (5, 555)") tk.MustExec("set session tidb_enable_cascades_planner = 1") var input []string var output []struct { diff --git a/planner/cascades/testdata/integration_suite_in.json b/planner/cascades/testdata/integration_suite_in.json index 8fb8a8e98c..5f1869a0bb 100644 --- a/planner/cascades/testdata/integration_suite_in.json +++ b/planner/cascades/testdata/integration_suite_in.json @@ -58,7 +58,9 @@ "name": "TestJoin", "cases": [ "select t1.a, t1.b from t1, t2 where t1.a = t2.a and t1.a > 2", - "select t1.a, t1.b from t1, t2 where t1.a > t2.a and t2.b > 200" + "select t1.a, t1.b from t1, t2 where t1.a > t2.a and t2.b > 200", + "select t1.a, t1.b from t1 left join t2 on t1.a = t2.a where t1.a > 2 and t2.b > 200", + "select t2.a, t2.b from t1 right join t2 on t1.a = t2.a where t1.a > 2 and t2.b > 200" ] }, { diff --git a/planner/cascades/testdata/integration_suite_out.json b/planner/cascades/testdata/integration_suite_out.json index ed45f281e2..cc1a036721 100644 --- a/planner/cascades/testdata/integration_suite_out.json +++ b/planner/cascades/testdata/integration_suite_out.json @@ -452,6 +452,37 @@ 
"4 44", "4 44" ] + }, + { + "SQL": "select t1.a, t1.b from t1 left join t2 on t1.a = t2.a where t1.a > 2 and t2.b > 200", + "Plan": [ + "Projection_17 3333.33 root test.t1.a, test.t1.b", + "└─Selection_18 3333.33 root gt(test.t2.b, 200)", + " └─HashLeftJoin_20 4166.67 root left outer join, inner:TableReader_23, equal:[eq(test.t1.a, test.t2.a)]", + " ├─TableReader_21 3333.33 root data:TableScan_22", + " │ └─TableScan_22 3333.33 cop[tikv] table:t1, range:(2,+inf], keep order:false, stats:pseudo", + " └─TableReader_23 3333.33 root data:TableScan_24", + " └─TableScan_24 3333.33 cop[tikv] table:t2, range:(2,+inf], keep order:false, stats:pseudo" + ], + "Result": [ + "3 33" + ] + }, + { + "SQL": "select t2.a, t2.b from t1 right join t2 on t1.a = t2.a where t1.a > 2 and t2.b > 200", + "Plan": [ + "Projection_13 8000.00 root test.t2.a, test.t2.b", + "└─Selection_14 8000.00 root gt(test.t1.a, 2)", + " └─HashRightJoin_16 10000.00 root right outer join, inner:TableReader_19 (REVERSED), equal:[eq(test.t1.a, test.t2.a)]", + " ├─TableReader_17 10000.00 root data:TableScan_18", + " │ └─TableScan_18 10000.00 cop[tikv] table:t1, range:[-inf,+inf], keep order:false, stats:pseudo", + " └─TableReader_19 8000.00 root data:Selection_20", + " └─Selection_20 8000.00 cop[tikv] gt(test.t2.b, 200)", + " └─TableScan_21 10000.00 cop[tikv] table:t2, range:[-inf,+inf], keep order:false, stats:pseudo" + ], + "Result": [ + "3 333" + ] } ] }, diff --git a/planner/cascades/transformation_rules.go b/planner/cascades/transformation_rules.go index 0d50d81db4..2a6220de65 100644 --- a/planner/cascades/transformation_rules.go +++ b/planner/cascades/transformation_rules.go @@ -680,6 +680,11 @@ func NewRulePushSelDownJoin() Transformation { return rule } +// Match implements Transformation interface. +func (r *PushSelDownJoin) Match(expr *memo.ExprIter) bool { + return !expr.GetExpr().HasAppliedRule(r) +} + // buildChildSelectionGroup builds a new childGroup if the pushed down condition is not empty. 
func buildChildSelectionGroup( oldSel *plannercore.LogicalSelection, @@ -706,9 +711,9 @@ func (r *PushSelDownJoin) OnTransform(old *memo.ExprIter) (newExprs []*memo.Grou leftGroup := old.Children[0].GetExpr().Children[0] rightGroup := old.Children[0].GetExpr().Children[1] var equalCond []*expression.ScalarFunction - var leftPushCond, rightPushCond, otherCond, leftCond, rightCond []expression.Expression + var leftPushCond, rightPushCond, otherCond, leftCond, rightCond, remainCond []expression.Expression switch join.JoinType { - case plannercore.InnerJoin: + case plannercore.SemiJoin, plannercore.InnerJoin: tempCond := make([]expression.Expression, 0, len(join.LeftConditions)+len(join.RightConditions)+len(join.EqualConditions)+len(join.OtherConditions)+len(sel.Conditions)) tempCond = append(tempCond, join.LeftConditions...) @@ -730,8 +735,59 @@ func (r *PushSelDownJoin) OnTransform(old *memo.ExprIter) (newExprs []*memo.Grou join.OtherConditions = otherCond leftCond = leftPushCond rightCond = rightPushCond + case plannercore.LeftOuterJoin, plannercore.LeftOuterSemiJoin, plannercore.AntiLeftOuterSemiJoin, + plannercore.RightOuterJoin: + lenJoinConds := len(join.EqualConditions) + len(join.LeftConditions) + len(join.RightConditions) + len(join.OtherConditions) + joinConds := make([]expression.Expression, 0, lenJoinConds) + for _, equalCond := range join.EqualConditions { + joinConds = append(joinConds, equalCond) + } + joinConds = append(joinConds, join.LeftConditions...) + joinConds = append(joinConds, join.RightConditions...) + joinConds = append(joinConds, join.OtherConditions...) 
+ join.EqualConditions = nil + join.LeftConditions = nil + join.RightConditions = nil + join.OtherConditions = nil + remainCond = make([]expression.Expression, len(sel.Conditions)) + copy(remainCond, sel.Conditions) + nullSensitive := join.JoinType == plannercore.AntiLeftOuterSemiJoin || join.JoinType == plannercore.LeftOuterSemiJoin + if join.JoinType == plannercore.RightOuterJoin { + joinConds, remainCond = expression.PropConstOverOuterJoin(join.SCtx(), joinConds, remainCond, rightGroup.Prop.Schema, leftGroup.Prop.Schema, nullSensitive) + } else { + joinConds, remainCond = expression.PropConstOverOuterJoin(join.SCtx(), joinConds, remainCond, leftGroup.Prop.Schema, rightGroup.Prop.Schema, nullSensitive) + } + join.AttachOnConds(joinConds) + // Return table dual when filter is constant false or null. + dual := plannercore.Conds2TableDual(join, remainCond) + if dual != nil { + return []*memo.GroupExpr{memo.NewGroupExpr(dual)}, false, true, nil + } + if join.JoinType == plannercore.RightOuterJoin { + remainCond = expression.ExtractFiltersFromDNFs(join.SCtx(), remainCond) + // Only derive right where condition, because left where condition cannot be pushed down + equalCond, leftPushCond, rightPushCond, otherCond = join.ExtractOnCondition(remainCond, leftGroup.Prop.Schema, rightGroup.Prop.Schema, false, true) + rightCond = rightPushCond + // Handle join conditions, only derive left join condition, because right join condition cannot be pushed down + derivedLeftJoinCond, _ := plannercore.DeriveOtherConditions(join, true, false) + leftCond = append(join.LeftConditions, derivedLeftJoinCond...) + join.LeftConditions = nil + remainCond = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) + remainCond = append(remainCond, leftPushCond...) 
+ } else { + remainCond = expression.ExtractFiltersFromDNFs(join.SCtx(), remainCond) + // Only derive left where condition, because right where condition cannot be pushed down + equalCond, leftPushCond, rightPushCond, otherCond = join.ExtractOnCondition(remainCond, leftGroup.Prop.Schema, rightGroup.Prop.Schema, true, false) + leftCond = leftPushCond + // Handle join conditions, only derive right join condition, because left join condition cannot be pushed down + _, derivedRightJoinCond := plannercore.DeriveOtherConditions(join, false, true) + rightCond = append(join.RightConditions, derivedRightJoinCond...) + join.RightConditions = nil + remainCond = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) + remainCond = append(remainCond, rightPushCond...) + } default: - // TODO: Enhance this rule to deal with LeftOuter/RightOuter/Semi/SmiAnti/LeftOuterSemi/LeftOuterSemiAnti Joins. + // TODO: Enhance this rule to deal with Semi/SemiAnti Joins. } leftCond = expression.RemoveDupExprs(sctx, leftCond) rightCond = expression.RemoveDupExprs(sctx, rightCond) @@ -740,10 +796,18 @@ func (r *PushSelDownJoin) OnTransform(old *memo.ExprIter) (newExprs []*memo.Grou join.RightJoinKeys = append(join.RightJoinKeys, eqCond.GetArgs()[1].(*expression.Column)) } // TODO: Update EqualConditions like what we have done in the method join.updateEQCond() before. 
- leftGroup = buildChildSelectionGroup(sel, leftCond, joinExpr.Children[0]) - rightGroup = buildChildSelectionGroup(sel, rightCond, joinExpr.Children[1]) + leftGroup = buildChildSelectionGroup(sel, leftCond, leftGroup) + rightGroup = buildChildSelectionGroup(sel, rightCond, rightGroup) newJoinExpr := memo.NewGroupExpr(join) newJoinExpr.SetChildren(leftGroup, rightGroup) + if len(remainCond) > 0 { + newSel := plannercore.LogicalSelection{Conditions: remainCond}.Init(sctx, sel.SelectBlockOffset()) + newSel.Conditions = remainCond + newSelExpr := memo.NewGroupExpr(newSel) + newSelExpr.SetChildren(memo.NewGroupWithSchema(newJoinExpr, old.Children[0].Prop.Schema)) + newSelExpr.AddAppliedRule(r) + return []*memo.GroupExpr{newSelExpr}, true, false, nil + } return []*memo.GroupExpr{newJoinExpr}, true, false, nil } diff --git a/planner/core/expression_rewriter.go b/planner/core/expression_rewriter.go index 5c6b0f6cb6..64944408de 100644 --- a/planner/core/expression_rewriter.go +++ b/planner/core/expression_rewriter.go @@ -814,7 +814,7 @@ func (er *expressionRewriter) handleInSubquery(ctx context.Context, v *ast.Patte join.names = make([]*types.FieldName, er.p.Schema().Len()+agg.Schema().Len()) copy(join.names, er.p.OutputNames()) copy(join.names[er.p.Schema().Len():], agg.OutputNames()) - join.attachOnConds(expression.SplitCNFItems(checkCondition)) + join.AttachOnConds(expression.SplitCNFItems(checkCondition)) // Set join hint for this join. 
if er.b.TableHints() != nil { join.setPreferredJoinType(er.b.TableHints()) diff --git a/planner/core/logical_plan_builder.go b/planner/core/logical_plan_builder.go index 4496dc401c..ac1e42acca 100644 --- a/planner/core/logical_plan_builder.go +++ b/planner/core/logical_plan_builder.go @@ -574,7 +574,7 @@ func (b *PlanBuilder) buildJoin(ctx context.Context, joinNode *ast.Join) (Logica return nil, errors.New("ON condition doesn't support subqueries yet") } onCondition := expression.SplitCNFItems(onExpr) - joinPlan.attachOnConds(onCondition) + joinPlan.AttachOnConds(onCondition) } else if joinPlan.JoinType == InnerJoin { // If a inner join without "ON" or "USING" clause, it's a cartesian // product over the join tables. @@ -2860,7 +2860,7 @@ func (b *PlanBuilder) buildSemiJoin(outerPlan, innerPlan LogicalPlan, onConditio onCondition[i] = expr.Decorrelate(outerPlan.Schema()) } joinPlan.SetChildren(outerPlan, innerPlan) - joinPlan.attachOnConds(onCondition) + joinPlan.AttachOnConds(onCondition) joinPlan.names = make([]*types.FieldName, outerPlan.Schema().Len(), outerPlan.Schema().Len()+innerPlan.Schema().Len()+1) copy(joinPlan.names, outerPlan.OutputNames()) if asScalar { diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index b015bc1280..536afda269 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -200,7 +200,9 @@ func (p *LogicalJoin) columnSubstitute(schema *expression.Schema, exprs []expres } } -func (p *LogicalJoin) attachOnConds(onConds []expression.Expression) { +// AttachOnConds extracts the on conditions of the join and sets `EqualConditions`, `LeftConditions`, `RightConditions` and +// `OtherConditions` from the extraction result. +func (p *LogicalJoin) AttachOnConds(onConds []expression.Expression) { eq, left, right, other := p.extractOnCondition(onConds, false, false) p.EqualConditions = append(eq, p.EqualConditions...) p.LeftConditions = append(left, p.LeftConditions...) 
diff --git a/planner/core/rule_decorrelate.go b/planner/core/rule_decorrelate.go index 63173c688a..67c1c67a99 100644 --- a/planner/core/rule_decorrelate.go +++ b/planner/core/rule_decorrelate.go @@ -115,7 +115,7 @@ func (s *decorrelateSolver) optimize(ctx context.Context, p LogicalPlan) (Logica for _, cond := range sel.Conditions { newConds = append(newConds, cond.Decorrelate(outerPlan.Schema())) } - apply.attachOnConds(newConds) + apply.AttachOnConds(newConds) innerPlan = sel.children[0] apply.SetChildren(outerPlan, innerPlan) return s.optimize(ctx, p) diff --git a/planner/core/rule_predicate_push_down.go b/planner/core/rule_predicate_push_down.go index 9fcaa14c11..2e6e99d551 100644 --- a/planner/core/rule_predicate_push_down.go +++ b/planner/core/rule_predicate_push_down.go @@ -130,7 +130,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, true, false) leftCond = leftPushCond // Handle join conditions, only derive right join condition, because left join condition cannot be pushed down - _, derivedRightJoinCond := deriveOtherConditions(p, false, true) + _, derivedRightJoinCond := DeriveOtherConditions(p, false, true) rightCond = append(p.RightConditions, derivedRightJoinCond...) p.RightConditions = nil ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) @@ -147,7 +147,7 @@ func (p *LogicalJoin) PredicatePushDown(predicates []expression.Expression) (ret equalCond, leftPushCond, rightPushCond, otherCond = p.extractOnCondition(predicates, false, true) rightCond = rightPushCond // Handle join conditions, only derive left join condition, because right join condition cannot be pushed down - derivedLeftJoinCond, _ := deriveOtherConditions(p, true, false) + derivedLeftJoinCond, _ := DeriveOtherConditions(p, true, false) leftCond = append(p.LeftConditions, derivedLeftJoinCond...) 
p.LeftConditions = nil ret = append(expression.ScalarFuncs2Exprs(equalCond), otherCond...) @@ -443,9 +443,9 @@ func (p *LogicalMaxOneRow) PredicatePushDown(predicates []expression.Expression) return predicates, p } -// deriveOtherConditions given a LogicalJoin, check the OtherConditions to see if we can derive more +// DeriveOtherConditions given a LogicalJoin, check the OtherConditions to see if we can derive more // conditions for left/right child pushdown. -func deriveOtherConditions(p *LogicalJoin, deriveLeft bool, deriveRight bool) (leftCond []expression.Expression, +func DeriveOtherConditions(p *LogicalJoin, deriveLeft bool, deriveRight bool) (leftCond []expression.Expression, rightCond []expression.Expression) { leftPlan, rightPlan := p.children[0], p.children[1] isOuterSemi := (p.JoinType == LeftOuterSemiJoin) || (p.JoinType == AntiLeftOuterSemiJoin) @@ -545,9 +545,9 @@ func (p *LogicalJoin) outerJoinPropConst(predicates []expression.Expression) []e p.LeftConditions = nil p.RightConditions = nil p.OtherConditions = nil - nullSensitive := (p.JoinType == AntiLeftOuterSemiJoin || p.JoinType == LeftOuterSemiJoin) + nullSensitive := p.JoinType == AntiLeftOuterSemiJoin || p.JoinType == LeftOuterSemiJoin joinConds, predicates = expression.PropConstOverOuterJoin(p.ctx, joinConds, predicates, outerTable.Schema(), innerTable.Schema(), nullSensitive) - p.attachOnConds(joinConds) + p.AttachOnConds(joinConds) return predicates }