From fe073242900a8b464c2f9b68c29057bd812cf01e Mon Sep 17 00:00:00 2001 From: Chengpeng Yan <41809508+Reminiscent@users.noreply.github.com> Date: Wed, 11 May 2022 21:10:34 +0800 Subject: [PATCH] planner: refactor the join reorder codes (#34380) ref pingcap/tidb#29932 --- planner/core/rule_join_reorder.go | 73 ++++++++++++++++++++++- planner/core/rule_join_reorder_dp.go | 3 +- planner/core/rule_join_reorder_dp_test.go | 22 ++++--- planner/core/rule_join_reorder_greedy.go | 57 ++---------------- 4 files changed, 93 insertions(+), 62 deletions(-) diff --git a/planner/core/rule_join_reorder.go b/planner/core/rule_join_reorder.go index d060c708f9..b08bd76bb7 100644 --- a/planner/core/rule_join_reorder.go +++ b/planner/core/rule_join_reorder.go @@ -21,6 +21,7 @@ import ( "sort" "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/util/plancodec" "github.com/pingcap/tidb/util/tracing" @@ -102,7 +103,10 @@ func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalP baseGroupSolver := &baseSingleGroupJoinOrderSolver{ ctx: ctx, otherConds: otherConds, + eqEdges: eqEdges, + joinTypes: joinTypes, } + originalSchema := p.Schema() // Not support outer join reorder when using the DP algorithm @@ -116,8 +120,6 @@ func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalP if len(curJoinGroup) > ctx.GetSessionVars().TiDBOptJoinReorderThreshold || !isSupportDP { groupSolver := &joinReorderGreedySolver{ baseSingleGroupJoinOrderSolver: baseGroupSolver, - eqEdges: eqEdges, - joinTypes: joinTypes, } p, err = groupSolver.solve(curJoinGroup, tracer) } else { @@ -125,7 +127,7 @@ func (s *joinReOrderSolver) optimizeRecursive(ctx sessionctx.Context, p LogicalP baseSingleGroupJoinOrderSolver: baseGroupSolver, } dpSolver.newJoin = dpSolver.newJoinWithEdges - p, err = dpSolver.solve(curJoinGroup, expression.ScalarFuncs2Exprs(eqEdges), tracer) + p, err = dpSolver.solve(curJoinGroup, tracer) } if err != nil { return nil, err @@ -168,6 +170,26 @@ type baseSingleGroupJoinOrderSolver struct { ctx sessionctx.Context curJoinGroup []*jrNode otherConds []expression.Expression + eqEdges []*expression.ScalarFunction + joinTypes []JoinType +} + +// generateJoinOrderNode used to derive the stats for the joinNodePlans and generate the jrNode groups based on the cost. +func (s *baseSingleGroupJoinOrderSolver) generateJoinOrderNode(joinNodePlans []LogicalPlan, tracer *joinReorderTrace) ([]*jrNode, error) { + joinGroup := make([]*jrNode, 0, len(joinNodePlans)) + for _, node := range joinNodePlans { + _, err := node.recursiveDeriveStats(nil) + if err != nil { + return nil, err + } + cost := s.baseNodeCumCost(node) + joinGroup = append(joinGroup, &jrNode{ + p: node, + cumCost: cost, + }) + tracer.appendLogicalJoinCost(node, cost) + } + return joinGroup, nil } // baseNodeCumCost calculate the cumulative cost of the node in the join group. @@ -179,6 +201,51 @@ func (s *baseSingleGroupJoinOrderSolver) baseNodeCumCost(groupNode LogicalPlan) return cost } +// checkConnection used to check whether two nodes have equal conditions or not. +func (s *baseSingleGroupJoinOrderSolver) checkConnection(leftPlan, rightPlan LogicalPlan) (leftNode, rightNode LogicalPlan, usedEdges []*expression.ScalarFunction, joinType JoinType) { + joinType = InnerJoin + leftNode, rightNode = leftPlan, rightPlan + for idx, edge := range s.eqEdges { + lCol := edge.GetArgs()[0].(*expression.Column) + rCol := edge.GetArgs()[1].(*expression.Column) + if leftPlan.Schema().Contains(lCol) && rightPlan.Schema().Contains(rCol) { + joinType = s.joinTypes[idx] + usedEdges = append(usedEdges, edge) + } else if rightPlan.Schema().Contains(lCol) && leftPlan.Schema().Contains(rCol) { + joinType = s.joinTypes[idx] + if joinType != InnerJoin { + rightNode, leftNode = leftPlan, rightPlan + usedEdges = append(usedEdges, edge) + } else { + newSf := expression.NewFunctionInternal(s.ctx, ast.EQ, edge.GetType(), rCol, lCol).(*expression.ScalarFunction) + usedEdges = append(usedEdges, newSf) + } + } + } + return +} + +// makeJoin build join tree for the nodes which have equal conditions to connect them. +func (s *baseSingleGroupJoinOrderSolver) makeJoin(leftPlan, rightPlan LogicalPlan, eqEdges []*expression.ScalarFunction, joinType JoinType) (LogicalPlan, []expression.Expression) { + remainOtherConds := make([]expression.Expression, len(s.otherConds)) + copy(remainOtherConds, s.otherConds) + var otherConds []expression.Expression + var leftConds []expression.Expression + var rightConds []expression.Expression + mergedSchema := expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema()) + + remainOtherConds, leftConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool { + return expression.ExprFromSchema(expr, leftPlan.Schema()) && !expression.ExprFromSchema(expr, rightPlan.Schema()) + }) + remainOtherConds, rightConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool { + return expression.ExprFromSchema(expr, rightPlan.Schema()) && !expression.ExprFromSchema(expr, leftPlan.Schema()) + }) + remainOtherConds, otherConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool { + return expression.ExprFromSchema(expr, mergedSchema) + }) + return s.newJoinWithEdges(leftPlan, rightPlan, eqEdges, otherConds, leftConds, rightConds, joinType), remainOtherConds +} + // makeBushyJoin build bushy tree for the nodes which have no equal condition to connect them. func (s *baseSingleGroupJoinOrderSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan { resultJoinGroup := make([]LogicalPlan, 0, (len(cartesianJoinGroup)+1)/2) diff --git a/planner/core/rule_join_reorder_dp.go b/planner/core/rule_join_reorder_dp.go index 58b3c28e89..c91d74e1b7 100644 --- a/planner/core/rule_join_reorder_dp.go +++ b/planner/core/rule_join_reorder_dp.go @@ -37,7 +37,8 @@ type joinGroupNonEqEdge struct { expr expression.Expression } -func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, eqConds []expression.Expression, tracer *joinReorderTrace) (LogicalPlan, error) { +func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, tracer *joinReorderTrace) (LogicalPlan, error) { + eqConds := expression.ScalarFuncs2Exprs(s.eqEdges) for _, node := range joinGroup { _, err := node.recursiveDeriveStats(nil) if err != nil { diff --git a/planner/core/rule_join_reorder_dp_test.go b/planner/core/rule_join_reorder_dp_test.go index df0bc3a649..d62fcbce0e 100644 --- a/planner/core/rule_join_reorder_dp_test.go +++ b/planner/core/rule_join_reorder_dp_test.go @@ -180,13 +180,21 @@ func TestDPReorderTPCHQ5(t *testing.T) { eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[4].Schema().Columns[0])) eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[3].Schema().Columns[0], joinGroups[4].Schema().Columns[0])) eqConds = append(eqConds, expression.NewFunctionInternal(ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[4].Schema().Columns[0], joinGroups[5].Schema().Columns[0])) - solver := &joinReorderDPSolver{ - baseSingleGroupJoinOrderSolver: &baseSingleGroupJoinOrderSolver{ - ctx: ctx, - }, - newJoin: newMockJoin(ctx, statsMap), + eqEdges := make([]*expression.ScalarFunction, 0, len(eqConds)) + for _, cond := range eqConds { + sf, isSF := cond.(*expression.ScalarFunction) + require.True(t, isSF) + eqEdges = append(eqEdges, sf) } - result, err := solver.solve(joinGroups, eqConds, nil) + baseGroupSolver := &baseSingleGroupJoinOrderSolver{ + ctx: ctx, + eqEdges: eqEdges, + } + solver := &joinReorderDPSolver{ + baseSingleGroupJoinOrderSolver: baseGroupSolver, + newJoin: newMockJoin(ctx, statsMap), + } + result, err := solver.solve(joinGroups, nil) require.NoError(t, err) expected := "MockJoin{supplier, MockJoin{lineitem, MockJoin{orders, MockJoin{customer, MockJoin{nation, region}}}}}" @@ -210,7 +218,7 @@ func TestDPReorderAllCartesian(t *testing.T) { }, newJoin: newMockJoin(ctx, statsMap), } - result, err := solver.solve(joinGroup, nil, nil) + result, err := solver.solve(joinGroup, nil) require.NoError(t, err) expected := "MockJoin{MockJoin{a, b}, MockJoin{c, d}}" diff --git a/planner/core/rule_join_reorder_greedy.go b/planner/core/rule_join_reorder_greedy.go index 525fb76a80..1d98067c7f 100644 --- a/planner/core/rule_join_reorder_greedy.go +++ b/planner/core/rule_join_reorder_greedy.go @@ -19,13 +19,10 @@ import ( "sort" "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/parser/ast" ) type joinReorderGreedySolver struct { *baseSingleGroupJoinOrderSolver - eqEdges []*expression.ScalarFunction - joinTypes []JoinType } // solve reorders the join nodes in the group based on a greedy algorithm. @@ -43,19 +40,11 @@ type joinReorderGreedySolver struct { // For the nodes and join trees which don't have a join equal condition to // connect them, we make a bushy join tree to do the cartesian joins finally. func (s *joinReorderGreedySolver) solve(joinNodePlans []LogicalPlan, tracer *joinReorderTrace) (LogicalPlan, error) { - for _, node := range joinNodePlans { - _, err := node.recursiveDeriveStats(nil) - if err != nil { - return nil, err - } - cost := s.baseNodeCumCost(node) - s.curJoinGroup = append(s.curJoinGroup, &jrNode{ - p: node, - cumCost: cost, - }) - tracer.appendLogicalJoinCost(node, cost) + var err error + s.curJoinGroup, err = s.generateJoinOrderNode(joinNodePlans, tracer) + if err != nil { + return nil, err } - // Sort plans by cost sort.SliceStable(s.curJoinGroup, func(i, j int) bool { return s.curJoinGroup[i].cumCost < s.curJoinGroup[j].cumCost @@ -114,43 +103,9 @@ func (s *joinReorderGreedySolver) constructConnectedJoinTree(tracer *joinReorder } func (s *joinReorderGreedySolver) checkConnectionAndMakeJoin(leftPlan, rightPlan LogicalPlan) (LogicalPlan, []expression.Expression) { - var usedEdges []*expression.ScalarFunction - remainOtherConds := make([]expression.Expression, len(s.otherConds)) - copy(remainOtherConds, s.otherConds) - joinType := InnerJoin - for idx, edge := range s.eqEdges { - lCol := edge.GetArgs()[0].(*expression.Column) - rCol := edge.GetArgs()[1].(*expression.Column) - if leftPlan.Schema().Contains(lCol) && rightPlan.Schema().Contains(rCol) { - joinType = s.joinTypes[idx] - usedEdges = append(usedEdges, edge) - } else if rightPlan.Schema().Contains(lCol) && leftPlan.Schema().Contains(rCol) { - joinType = s.joinTypes[idx] - if joinType != InnerJoin { - rightPlan, leftPlan = leftPlan, rightPlan - usedEdges = append(usedEdges, edge) - } else { - newSf := expression.NewFunctionInternal(s.ctx, ast.EQ, edge.GetType(), rCol, lCol).(*expression.ScalarFunction) - usedEdges = append(usedEdges, newSf) - } - } - } + leftPlan, rightPlan, usedEdges, joinType := s.checkConnection(leftPlan, rightPlan) if len(usedEdges) == 0 { return nil, nil } - var otherConds []expression.Expression - var leftConds []expression.Expression - var rightConds []expression.Expression - mergedSchema := expression.MergeSchema(leftPlan.Schema(), rightPlan.Schema()) - - remainOtherConds, leftConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool { - return expression.ExprFromSchema(expr, leftPlan.Schema()) && !expression.ExprFromSchema(expr, rightPlan.Schema()) - }) - remainOtherConds, rightConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool { - return expression.ExprFromSchema(expr, rightPlan.Schema()) && !expression.ExprFromSchema(expr, leftPlan.Schema()) - }) - remainOtherConds, otherConds = expression.FilterOutInPlace(remainOtherConds, func(expr expression.Expression) bool { - return expression.ExprFromSchema(expr, mergedSchema) - }) - return s.newJoinWithEdges(leftPlan, rightPlan, usedEdges, otherConds, leftConds, rightConds, joinType), remainOtherConds + return s.makeJoin(leftPlan, rightPlan, usedEdges, joinType) }