diff --git a/planner/core/rule_join_reorder.go b/planner/core/rule_join_reorder.go
index ecbbd894fb..085ee11912 100644
--- a/planner/core/rule_join_reorder.go
+++ b/planner/core/rule_join_reorder.go
@@ -72,7 +72,7 @@ func getCartesianJoinGroup(p *LogicalJoin) []LogicalPlan {
 	return append(lhsJoinGroup, rChild)
 }
 
-func findColumnIndexByGroup(groups []LogicalPlan, col *expression.Column) int {
+func findNodeIndexInGroup(groups []LogicalPlan, col *expression.Column) int {
 	for i, plan := range groups {
 		if plan.Schema().Contains(col) {
 			return i
@@ -143,8 +143,8 @@ func (e *joinReOrderSolver) reorderJoin(group []LogicalPlan, conds []expression.
 			lCol, lok := f.GetArgs()[0].(*expression.Column)
 			rCol, rok := f.GetArgs()[1].(*expression.Column)
 			if lok && rok {
-				lID := findColumnIndexByGroup(group, lCol)
-				rID := findColumnIndexByGroup(group, rCol)
+				lID := findNodeIndexInGroup(group, lCol)
+				rID := findNodeIndexInGroup(group, rCol)
 				if lID != rID {
 					e.graph[lID] = append(e.graph[lID], &rankInfo{nodeID: rID})
 					e.graph[rID] = append(e.graph[rID], &rankInfo{nodeID: lID})
@@ -156,7 +156,7 @@ func (e *joinReOrderSolver) reorderJoin(group []LogicalPlan, conds []expression.
 			rate := 1.0
 			cols := expression.ExtractColumns(f)
 			for _, col := range cols {
-				idx := findColumnIndexByGroup(group, col)
+				idx := findNodeIndexInGroup(group, col)
 				if id == -1 {
 					switch f.FuncName.L {
 					case ast.EQ:
diff --git a/planner/core/rule_join_reorder_dp.go b/planner/core/rule_join_reorder_dp.go
new file mode 100644
index 0000000000..0864bf864e
--- /dev/null
+++ b/planner/core/rule_join_reorder_dp.go
@@ -0,0 +1,230 @@
+// Copyright 2018 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package core
+
+import (
+	"math/bits"
+
+	"github.com/pingcap/parser/ast"
+	"github.com/pingcap/tidb/expression"
+	"github.com/pingcap/tidb/sessionctx"
+)
+
+// joinReorderDPSolver reorders the nodes of a join group with dynamic
+// programming over subsets of the nodes.
+type joinReorderDPSolver struct {
+	ctx sessionctx.Context
+
+	// newJoin builds the join of lChild and rChild on the conditions eqConds.
+	newJoin func(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction) LogicalPlan
+}
+
+// joinGroupEdge is an edge of the join graph: a join condition together with
+// the two nodes of the join group it connects.
+type joinGroupEdge struct {
+	// nodeIDs holds the indexes in the join group of the two connected nodes.
+	nodeIDs []int
+
+	// edge is the condition connecting the two nodes.
+	edge *expression.ScalarFunction
+}
+
+// solve builds the join graph of joinGroup from conds, reorders every connected
+// sub graph with dpGraph and joins the results of the sub graphs with cartesian
+// joins. Every element of conds must be a ScalarFunction whose first two
+// arguments are columns, otherwise the type assertions below panic. The result
+// is nil when joinGroup is empty.
+func (s *joinReorderDPSolver) solve(joinGroup []LogicalPlan, conds []expression.Expression) (LogicalPlan, error) {
+	adjacents := make([][]int, len(joinGroup))
+	totalEdges := make([]joinGroupEdge, 0, len(conds))
+	addEdge := func(node1, node2 int, edgeContent *expression.ScalarFunction) {
+		totalEdges = append(totalEdges, joinGroupEdge{
+			nodeIDs: []int{node1, node2},
+			edge:    edgeContent,
+		})
+		adjacents[node1] = append(adjacents[node1], node2)
+		adjacents[node2] = append(adjacents[node2], node1)
+	}
+	// Build the graph of the join group.
+	for _, cond := range conds {
+		sf := cond.(*expression.ScalarFunction)
+		lCol := sf.GetArgs()[0].(*expression.Column)
+		rCol := sf.GetArgs()[1].(*expression.Column)
+		lIdx := findNodeIndexInGroup(joinGroup, lCol)
+		rIdx := findNodeIndexInGroup(joinGroup, rCol)
+		addEdge(lIdx, rIdx, sf)
+	}
+	visited := make([]bool, len(joinGroup))
+	nodeID2VisitID := make([]int, len(joinGroup))
+	var joins []LogicalPlan
+	// BFS the graph, one connected sub graph at a time.
+	for i := 0; i < len(joinGroup); i++ {
+		if visited[i] {
+			continue
+		}
+		visitID2NodeID := s.bfsGraph(i, visited, adjacents, nodeID2VisitID)
+		// Do DP on each sub graph.
+		join, err := s.dpGraph(visitID2NodeID, nodeID2VisitID, joinGroup, totalEdges)
+		if err != nil {
+			return nil, err
+		}
+		joins = append(joins, join)
+	}
+	// Build bushy tree for cartesian joins.
+	return s.makeBushyJoin(joins), nil
+}
+
+// bfsGraph runs a BFS over the sub graph that contains startNode and marks every
+// node it reaches in visited. The reached nodes are relabeled by visit order:
+// nodeID2VisitID[nodeID] is set to the visit ID of each reached node, and the
+// returned slice maps the visit IDs back to node IDs.
+func (s *joinReorderDPSolver) bfsGraph(startNode int, visited []bool, adjacents [][]int, nodeID2VisitID []int) []int {
+	queue := []int{startNode}
+	visited[startNode] = true
+	var visitID2NodeID []int
+	for len(queue) > 0 {
+		curNodeID := queue[0]
+		queue = queue[1:]
+		nodeID2VisitID[curNodeID] = len(visitID2NodeID)
+		visitID2NodeID = append(visitID2NodeID, curNodeID)
+		for _, adjNodeID := range adjacents[curNodeID] {
+			if visited[adjNodeID] {
+				continue
+			}
+			queue = append(queue, adjNodeID)
+			visited[adjNodeID] = true
+		}
+	}
+	return visitID2NodeID
+}
+
+// dpGraph finds the join tree with the lowest cost for one connected sub graph.
+// newPos2OldPos maps the visit IDs of the sub graph to node IDs (indexes in
+// joinGroup) and oldPos2NewPos maps node IDs to visit IDs. A set of nodes is a
+// bitmap over visit IDs, and the cost of a plan is the sum of the row counts of
+// all the joins in it.
+func (s *joinReorderDPSolver) dpGraph(newPos2OldPos, oldPos2NewPos []int, joinGroup []LogicalPlan, totalEdges []joinGroupEdge) (LogicalPlan, error) {
+	// Keep only the edges of this sub graph. Visit IDs restart from 0 in every
+	// sub graph, so an edge of another sub graph would otherwise be taken for
+	// an edge between the two nodes here that share its visit IDs.
+	subEdges := make([]joinGroupEdge, 0, len(totalEdges))
+	for _, edge := range totalEdges {
+		nodeID := edge.nodeIDs[0]
+		if pos := oldPos2NewPos[nodeID]; pos < len(newPos2OldPos) && newPos2OldPos[pos] == nodeID {
+			subEdges = append(subEdges, edge)
+		}
+	}
+	nodeCnt := uint(len(newPos2OldPos))
+	// bestPlan[set] is nil as long as no plan for set is known, i.e. its cost is +inf.
+	bestPlan := make([]LogicalPlan, 1<<nodeCnt)
+	// NOTE(review): the element type has to match statsInfo().Count(), assumed int64.
+	bestCost := make([]int64, 1<<nodeCnt)
+	// A single node is its own plan and costs nothing.
+	for i := uint(0); i < nodeCnt; i++ {
+		bestPlan[1<<i] = joinGroup[newPos2OldPos[i]]
+	}
+	// Enumerate the bitmaps from small to big, so every subset of a set is
+	// solved before the set itself.
+	for nodeBitmap := uint(1); nodeBitmap < (1 << nodeCnt); nodeBitmap++ {
+		if bits.OnesCount(nodeBitmap) == 1 {
+			continue
+		}
+		// Iterate over every non-empty proper subset of nodeBitmap.
+		for sub := (nodeBitmap - 1) & nodeBitmap; sub > 0; sub = (sub - 1) & nodeBitmap {
+			remain := nodeBitmap ^ sub
+			// (sub, remain) and (remain, sub) are the same split, handle it once.
+			if sub > remain {
+				continue
+			}
+			// If this subset is not connected skip it.
+			if bestPlan[sub] == nil || bestPlan[remain] == nil {
+				continue
+			}
+			// Get the edge connecting the two parts.
+			usedEdges := s.nodesAreConnected(sub, remain, oldPos2NewPos, subEdges)
+			if len(usedEdges) == 0 {
+				continue
+			}
+			join, err := s.newJoinWithEdge(bestPlan[sub], bestPlan[remain], usedEdges)
+			if err != nil {
+				return nil, err
+			}
+			curCost := join.statsInfo().Count() + bestCost[remain] + bestCost[sub]
+			if bestPlan[nodeBitmap] == nil || bestCost[nodeBitmap] > curCost {
+				bestPlan[nodeBitmap] = join
+				bestCost[nodeBitmap] = curCost
+			}
+		}
+	}
+	return bestPlan[(1<<nodeCnt)-1], nil
+}
+
+// nodesAreConnected returns the edges in totalEdges that have one end in
+// leftMask and the other end in rightMask. The result is empty when the two
+// sets of nodes are not connected.
+func (s *joinReorderDPSolver) nodesAreConnected(leftMask, rightMask uint, oldPos2NewPos []int, totalEdges []joinGroupEdge) []joinGroupEdge {
+	var usedEdges []joinGroupEdge
+	for _, edge := range totalEdges {
+		lIdx := uint(oldPos2NewPos[edge.nodeIDs[0]])
+		rIdx := uint(oldPos2NewPos[edge.nodeIDs[1]])
+		if (leftMask&(1<<lIdx)) > 0 && (rightMask&(1<<rIdx)) > 0 {
+			usedEdges = append(usedEdges, edge)
+		} else if (leftMask&(1<<rIdx)) > 0 && (rightMask&(1<<lIdx)) > 0 {
+			usedEdges = append(usedEdges, edge)
+		}
+	}
+	return usedEdges
+}
+
+// newJoinWithEdge joins leftPlan and rightPlan on the conditions of edges and
+// derives the stats of the new join. A condition whose first column is not in
+// the schema of leftPlan is rebuilt with its two columns swapped.
+func (s *joinReorderDPSolver) newJoinWithEdge(leftPlan, rightPlan LogicalPlan, edges []joinGroupEdge) (LogicalPlan, error) {
+	var eqConds []*expression.ScalarFunction
+	for _, edge := range edges {
+		lCol := edge.edge.GetArgs()[0].(*expression.Column)
+		rCol := edge.edge.GetArgs()[1].(*expression.Column)
+		if leftPlan.Schema().Contains(lCol) {
+			eqConds = append(eqConds, edge.edge)
+		} else {
+			newSf := expression.NewFunctionInternal(s.ctx, ast.EQ, edge.edge.GetType(), rCol, lCol).(*expression.ScalarFunction)
+			eqConds = append(eqConds, newSf)
+		}
+	}
+	join := s.newJoin(leftPlan, rightPlan, eqConds)
+	_, err := join.deriveStats()
+	return join, err
+}
+
+// makeBushyJoin combines the plans of cartesianJoinGroup pairwise, level by
+// level, into a bushy tree of joins without conditions. It returns nil when
+// cartesianJoinGroup is empty.
+func (s *joinReorderDPSolver) makeBushyJoin(cartesianJoinGroup []LogicalPlan) LogicalPlan {
+	if len(cartesianJoinGroup) == 0 {
+		return nil
+	}
+	for len(cartesianJoinGroup) > 1 {
+		resultJoinGroup := make([]LogicalPlan, 0, len(cartesianJoinGroup))
+		for i := 0; i < len(cartesianJoinGroup); i += 2 {
+			// A plan without a partner moves up to the next level as is.
+			if i+1 == len(cartesianJoinGroup) {
+				resultJoinGroup = append(resultJoinGroup, cartesianJoinGroup[i])
+				break
+			}
+			resultJoinGroup = append(resultJoinGroup, s.newJoin(cartesianJoinGroup[i], cartesianJoinGroup[i+1], nil))
+		}
+		cartesianJoinGroup = resultJoinGroup
+	}
+	return cartesianJoinGroup[0]
+}
diff --git a/planner/core/rule_join_reorder_dp_test.go b/planner/core/rule_join_reorder_dp_test.go
new file mode 100644
index 0000000000..dfaebc44a7
--- /dev/null
+++ b/planner/core/rule_join_reorder_dp_test.go
@@ -0,0 +1,256 @@
+// Copyright 2018 PingCAP, Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package core
+
+import (
+	"fmt"
+
+	. "github.com/pingcap/check"
+	"github.com/pingcap/parser/ast"
+	"github.com/pingcap/parser/model"
+	"github.com/pingcap/parser/mysql"
+	"github.com/pingcap/tidb/expression"
+	"github.com/pingcap/tidb/planner/property"
+	"github.com/pingcap/tidb/sessionctx"
+	"github.com/pingcap/tidb/types"
+)
+
+var _ = Suite(&testJoinReorderDPSuite{})
+
+// testJoinReorderDPSuite tests joinReorderDPSolver with mocked joins and stats.
+type testJoinReorderDPSuite struct {
+	ctx      sessionctx.Context
+	statsMap map[int]*property.StatsInfo
+}
+
+// SetUpTest creates a new mock context and sets its PlanID to -1 before every test.
+func (s *testJoinReorderDPSuite) SetUpTest(c *C) {
+	s.ctx = mockContext()
+	s.ctx.GetSessionVars().PlanID = -1
+}
+
+// mockLogicalJoin is a join whose stats are looked up in statsMap by
+// involvedNodeSet, the bitmap of the nodes it covers.
+type mockLogicalJoin struct {
+	logicalSchemaProducer
+	involvedNodeSet int
+	statsMap        map[int]*property.StatsInfo
+}
+
+// init initializes the base plan of a copy of mj and returns that copy.
+func (mj mockLogicalJoin) init(ctx sessionctx.Context) *mockLogicalJoin {
+	mj.baseLogicalPlan = newBaseLogicalPlan(ctx, "MockLogicalJoin", &mj)
+	return &mj
+}
+
+// deriveStats returns the mocked stats of the node set covered by mj.
+func (mj *mockLogicalJoin) deriveStats() (*property.StatsInfo, error) {
+	if mj.stats == nil {
+		mj.stats = mj.statsMap[mj.involvedNodeSet]
+	}
+	return mj.statsMap[mj.involvedNodeSet], nil
+}
+
+// newMockJoin builds a mockLogicalJoin on lChild and rChild. A child that is not
+// a mockLogicalJoin adds the bit 1 << ID() to the node set of the new join.
+func (s *testJoinReorderDPSuite) newMockJoin(lChild, rChild LogicalPlan, eqConds []*expression.ScalarFunction) LogicalPlan {
+	retJoin := mockLogicalJoin{}.init(s.ctx)
+	retJoin.schema = expression.MergeSchema(lChild.Schema(), rChild.Schema())
+	retJoin.statsMap = s.statsMap
+	if mj, ok := lChild.(*mockLogicalJoin); ok {
+		retJoin.involvedNodeSet = mj.involvedNodeSet
+	} else {
+		retJoin.involvedNodeSet = 1 << uint(lChild.ID())
+	}
+	if mj, ok := rChild.(*mockLogicalJoin); ok {
+		retJoin.involvedNodeSet |= mj.involvedNodeSet
+	} else {
+		retJoin.involvedNodeSet |= 1 << uint(rChild.ID())
+	}
+	retJoin.SetChildren(lChild, rChild)
+	return retJoin
+}
+
+// mockStatsInfo sets the row count of the joins that cover the node set state.
+func (s *testJoinReorderDPSuite) mockStatsInfo(state int, count float64) {
+	s.statsMap[state] = &property.StatsInfo{
+		RowCount: count,
+	}
+}
+
+// makeStatsMapForTPCHQ5 mocks the row counts for the join graph of TPC-H Q5.
+func (s *testJoinReorderDPSuite) makeStatsMapForTPCHQ5() {
+	// Labeled as lineitem -> 0, orders -> 1, customer -> 2, supplier -> 3, nation -> 4, region -> 5
+	// This graph can be shown as following:
+	// +---------------+            +---------------+
+	// |               |            |               |
+	// |    lineitem   +------------+    orders     |
+	// |               |            |               |
+	// +-------+-------+            +-------+-------+
+	//         |                            |
+	//         |                            |
+	//         |                            |
+	// +-------+-------+            +-------+-------+
+	// |               |            |               |
+	// |   supplier    +------------+   customer    |
+	// |               |            |               |
+	// +-------+-------+            +-------+-------+
+	//         |                            |
+	//         |                            |
+	//         |                            |
+	//         |                            |
+	//         |      +---------------+     |
+	//         |      |               |     |
+	//         +------+    nation     +-----+
+	//                |               |
+	//                +---------------+
+	//                        |
+	//                +---------------+
+	//                |               |
+	//                |    region     |
+	//                |               |
+	//                +---------------+
+	s.statsMap = make(map[int]*property.StatsInfo)
+	s.mockStatsInfo(3, 9103367)
+	s.mockStatsInfo(6, 2275919)
+	s.mockStatsInfo(7, 9103367)
+	s.mockStatsInfo(9, 59986052)
+	s.mockStatsInfo(11, 9103367)
+	s.mockStatsInfo(12, 5999974575)
+	s.mockStatsInfo(13, 59999974575)
+	s.mockStatsInfo(14, 9103543072)
+	s.mockStatsInfo(15, 99103543072)
+	s.mockStatsInfo(20, 1500000)
+	s.mockStatsInfo(22, 2275919)
+	s.mockStatsInfo(23, 7982159)
+	s.mockStatsInfo(24, 100000)
+	s.mockStatsInfo(25, 59986052)
+	s.mockStatsInfo(27, 9103367)
+	s.mockStatsInfo(28, 5999974575)
+	s.mockStatsInfo(29, 59999974575)
+	s.mockStatsInfo(30, 59999974575)
+	s.mockStatsInfo(31, 59999974575)
+	s.mockStatsInfo(48, 5)
+	s.mockStatsInfo(52, 299838)
+	s.mockStatsInfo(54, 454183)
+	s.mockStatsInfo(55, 1815222)
+	s.mockStatsInfo(56, 20042)
+	s.mockStatsInfo(57, 12022687)
+	s.mockStatsInfo(59, 1823514)
+	s.mockStatsInfo(60, 1201884359)
+	s.mockStatsInfo(61, 12001884359)
+	s.mockStatsInfo(62, 12001884359)
+	s.mockStatsInfo(63, 72985)
+}
+
+// newDataSource creates a DataSource named name with the single column <name>_a.
+func (s *testJoinReorderDPSuite) newDataSource(name string) LogicalPlan {
+	ds := DataSource{}.Init(s.ctx)
+	tan := model.NewCIStr(name)
+	ds.TableAsName = &tan
+	ds.schema = expression.NewSchema()
+	s.ctx.GetSessionVars().PlanColumnID++
+	ds.schema.Append(&expression.Column{
+		UniqueID: s.ctx.GetSessionVars().PlanColumnID,
+		ColName:  model.NewCIStr(fmt.Sprintf("%s_a", name)),
+		TblName:  model.NewCIStr(name),
+		DBName:   model.NewCIStr("test"),
+		RetType:  types.NewFieldType(mysql.TypeLonglong),
+	})
+	return ds
+}
+
+// planToString renders plan as nested MockJoin{left, right} with the
+// TableAsName of the data sources as leaves. It returns "" for other plans.
+func (s *testJoinReorderDPSuite) planToString(plan LogicalPlan) string {
+	switch x := plan.(type) {
+	case *mockLogicalJoin:
+		return fmt.Sprintf("MockJoin{%v, %v}", s.planToString(x.children[0]), s.planToString(x.children[1]))
+	case *DataSource:
+		return x.TableAsName.L
+	}
+	return ""
+}
+
+// TestDPReorderTPCHQ5 checks the join order chosen for the join graph of TPC-H Q5.
+func (s *testJoinReorderDPSuite) TestDPReorderTPCHQ5(c *C) {
+	s.makeStatsMapForTPCHQ5()
+	joinGroups := make([]LogicalPlan, 0, 6)
+	joinGroups = append(joinGroups, s.newDataSource("lineitem"))
+	joinGroups = append(joinGroups, s.newDataSource("orders"))
+	joinGroups = append(joinGroups, s.newDataSource("customer"))
+	joinGroups = append(joinGroups, s.newDataSource("supplier"))
+	joinGroups = append(joinGroups, s.newDataSource("nation"))
+	joinGroups = append(joinGroups, s.newDataSource("region"))
+	var eqConds []expression.Expression
+	eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[0].Schema().Columns[0], joinGroups[1].Schema().Columns[0]))
+	eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[1].Schema().Columns[0], joinGroups[2].Schema().Columns[0]))
+	eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[3].Schema().Columns[0]))
+	eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[0].Schema().Columns[0], joinGroups[3].Schema().Columns[0]))
+	eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[2].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
+	eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[3].Schema().Columns[0], joinGroups[4].Schema().Columns[0]))
+	eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), joinGroups[4].Schema().Columns[0], joinGroups[5].Schema().Columns[0]))
+	solver := &joinReorderDPSolver{
+		ctx:     s.ctx,
+		newJoin: s.newMockJoin,
+	}
+	result, err := solver.solve(joinGroups, eqConds)
+	c.Assert(err, IsNil)
+	c.Assert(s.planToString(result), Equals, "MockJoin{supplier, MockJoin{lineitem, MockJoin{orders, MockJoin{customer, MockJoin{nation, region}}}}}")
+}
+
+// TestDPReorderAllCartesian checks that a join group without any condition is
+// joined as a bushy tree.
+func (s *testJoinReorderDPSuite) TestDPReorderAllCartesian(c *C) {
+	joinGroup := make([]LogicalPlan, 0, 4)
+	joinGroup = append(joinGroup, s.newDataSource("a"))
+	joinGroup = append(joinGroup, s.newDataSource("b"))
+	joinGroup = append(joinGroup, s.newDataSource("c"))
+	joinGroup = append(joinGroup, s.newDataSource("d"))
+	solver := &joinReorderDPSolver{
+		ctx:     s.ctx,
+		newJoin: s.newMockJoin,
+	}
+	result, err := solver.solve(joinGroup, nil)
+	c.Assert(err, IsNil)
+	c.Assert(s.planToString(result), Equals, "MockJoin{MockJoin{a, b}, MockJoin{c, d}}")
+}
+
+// TestDPReorderTwoSubGraphs checks that the edges of one connected sub graph are
+// not used while another sub graph is reordered: the edge a-c must not be taken
+// for an edge d-f. No stats are mocked for the join of d and f on purpose.
+func (s *testJoinReorderDPSuite) TestDPReorderTwoSubGraphs(c *C) {
+	s.statsMap = make(map[int]*property.StatsInfo)
+	for _, state := range []int{3, 5, 6, 7, 24, 48, 56} {
+		s.mockStatsInfo(state, 100)
+	}
+	names := []string{"a", "b", "c", "d", "e", "f"}
+	joinGroup := make([]LogicalPlan, 0, len(names))
+	for _, name := range names {
+		joinGroup = append(joinGroup, s.newDataSource(name))
+	}
+	var eqConds []expression.Expression
+	for _, edge := range [][2]int{{0, 1}, {1, 2}, {0, 2}, {3, 4}, {4, 5}} {
+		lCol := joinGroup[edge[0]].Schema().Columns[0]
+		rCol := joinGroup[edge[1]].Schema().Columns[0]
+		eqConds = append(eqConds, expression.NewFunctionInternal(s.ctx, ast.EQ, types.NewFieldType(mysql.TypeTiny), lCol, rCol))
+	}
+	solver := &joinReorderDPSolver{
+		ctx:     s.ctx,
+		newJoin: s.newMockJoin,
+	}
+	result, err := solver.solve(joinGroup, eqConds)
+	c.Assert(err, IsNil)
+	c.Assert(s.planToString(result), Equals, "MockJoin{MockJoin{MockJoin{a, b}, c}, MockJoin{MockJoin{d, e}, f}}")
+}