diff --git a/planner/core/rule_join_reorder_greedy.go b/planner/core/rule_join_reorder_greedy.go index 18da5b386e..f564010833 100644 --- a/planner/core/rule_join_reorder_greedy.go +++ b/planner/core/rule_join_reorder_greedy.go @@ -49,6 +49,9 @@ func extractJoinGroup(p LogicalPlan) (group []LogicalPlan, eqEdges []*expression } type joinReOrderGreedySolver struct { +} + +type joinReorderGreedySingleGroupSolver struct { ctx sessionctx.Context curJoinGroup []LogicalPlan eqEdges []*expression.ScalarFunction @@ -67,7 +70,7 @@ type joinReOrderGreedySolver struct { // // For the nodes and join trees which don't have a join equal condition to // connect them, we make a bushy join tree to do the cartesian joins finally. -func (s *joinReOrderGreedySolver) solve() (LogicalPlan, error) { +func (s *joinReorderGreedySingleGroupSolver) solve() (LogicalPlan, error) { for _, node := range s.curJoinGroup { _, err := node.deriveStats() if err != nil { @@ -90,7 +93,7 @@ func (s *joinReOrderGreedySolver) solve() (LogicalPlan, error) { return s.makeBushyJoin(cartesianGroup), nil } -func (s *joinReOrderGreedySolver) constructConnectedJoinTree() (LogicalPlan, error) { +func (s *joinReorderGreedySingleGroupSolver) constructConnectedJoinTree() (LogicalPlan, error) { curJoinTree := s.curJoinGroup[0] s.curJoinGroup = s.curJoinGroup[1:] for { @@ -126,7 +129,7 @@ func (s *joinReOrderGreedySolver) constructConnectedJoinTree() (LogicalPlan, err return curJoinTree, nil } -func (s *joinReOrderGreedySolver) checkConnectionAndMakeJoin(leftNode, rightNode LogicalPlan) (LogicalPlan, []expression.Expression) { +func (s *joinReorderGreedySingleGroupSolver) checkConnectionAndMakeJoin(leftNode, rightNode LogicalPlan) (LogicalPlan, []expression.Expression) { var usedEdges []*expression.ScalarFunction remainOtherConds := make([]expression.Expression, len(s.otherConds)) copy(remainOtherConds, s.otherConds) @@ -159,7 +162,7 @@ func (s *joinReOrderGreedySolver) checkConnectionAndMakeJoin(leftNode, rightNode return newJoin, remainOtherConds } -func (s *joinReOrderGreedySolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan { +func (s *joinReorderGreedySingleGroupSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan { resultJoinGroup := make([]LogicalPlan, 0, (len(cartesianJoinGroup)+1)/2) for len(cartesianJoinGroup) > 1 { resultJoinGroup = resultJoinGroup[:0] @@ -183,7 +186,7 @@ func (s *joinReOrderGreedySolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan return cartesianJoinGroup[0] } -func (s *joinReOrderGreedySolver) newCartesianJoin(lChild, rChild LogicalPlan) *LogicalJoin { +func (s *joinReorderGreedySingleGroupSolver) newCartesianJoin(lChild, rChild LogicalPlan) *LogicalJoin { join := LogicalJoin{ JoinType: InnerJoin, reordered: true, @@ -194,22 +197,26 @@ func (s *joinReOrderGreedySolver) newCartesianJoin(lChild, rChild LogicalPlan) * } func (s *joinReOrderGreedySolver) optimize(p LogicalPlan) (LogicalPlan, error) { - s.ctx = p.context() - return s.optimizeRecursive(p) + return s.optimizeRecursive(p.context(), p) } -func (s *joinReOrderGreedySolver) optimizeRecursive(p LogicalPlan) (LogicalPlan, error) { +func (s *joinReOrderGreedySolver) optimizeRecursive(ctx sessionctx.Context, p LogicalPlan) (LogicalPlan, error) { var err error curJoinGroup, eqEdges, otherConds := extractJoinGroup(p) if len(curJoinGroup) > 1 { for i := range curJoinGroup { - curJoinGroup[i], err = s.optimizeRecursive(curJoinGroup[i]) + curJoinGroup[i], err = s.optimizeRecursive(ctx, curJoinGroup[i]) if err != nil { return nil, err } } - s.curJoinGroup, s.eqEdges, s.otherConds = curJoinGroup, eqEdges, otherConds - p, err = s.solve() + groupSolver := &joinReorderGreedySingleGroupSolver{ + ctx: ctx, + curJoinGroup: curJoinGroup, + eqEdges: eqEdges, + otherConds: otherConds, + } + p, err = groupSolver.solve() if err != nil { return nil, err } @@ -217,7 +224,7 @@ func (s *joinReOrderGreedySolver) optimizeRecursive(p LogicalPlan) (LogicalPlan, } newChildren := make([]LogicalPlan, 0, len(p.Children())) for _, child := range p.Children() { - newChild, err := s.optimizeRecursive(child) + newChild, err := s.optimizeRecursive(ctx, child) if err != nil { return nil, err }