diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java index e6c247fdee..b2bdd4583a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/Edge.java @@ -19,10 +19,17 @@ package org.apache.doris.nereids.jobs.joinorder.hypergraph; import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import com.google.common.base.Preconditions; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + /** * Edge in HyperGraph */ @@ -35,6 +42,10 @@ public class Edge { // left and right may not overlap, and both must have at least one bit set. private long left = LongBitmap.newBitmap(); private long right = LongBitmap.newBitmap(); + + private long originalLeft = LongBitmap.newBitmap(); + private long originalRight = LongBitmap.newBitmap(); + private long referenceNodes = LongBitmap.newBitmap(); /** @@ -100,6 +111,22 @@ public class Edge { this.right = right; } + public long getOriginalLeft() { + return originalLeft; + } + + public void setOriginalLeft(long left) { + this.originalLeft = left; + } + + public long getOriginalRight() { + return originalRight; + } + + public void setOriginalRight(long right) { + this.originalRight = right; + } + public boolean isSub(Edge edge) { // When this join reference nodes is a subset of other join, then this join must appear before that join long otherBitmap = edge.getReferenceNodes(); @@ -122,9 +149,20 @@ public class Edge { } public Expression getExpression() { + Preconditions.checkArgument(join.getExpressions().size() == 1); return join.getExpressions().get(0); } + public List getExpressions() { + return join.getExpressions(); + } + + public final Set getInputSlots() { + Set slots = new HashSet<>(); + join.getExpressions().stream().forEach(expression -> slots.addAll(expression.getInputSlots())); + return slots; + } + @Override public String toString() { return String.format("<%s - %s>", LongBitmap.toString(left), LongBitmap.toString(right)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java index 2144f8bee6..4ca7ddcdfb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraph.java @@ -24,17 +24,18 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -131,13 +132,31 @@ public class HyperGraph { public void addEdge(Group group) { Preconditions.checkArgument(group.isJoinGroup()); LogicalJoin join = (LogicalJoin) group.getLogicalExpression().getPlan(); - for (Expression expression : join.getExpressions()) { - LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), ImmutableList.of(expression), join.left(), - join.right()); - Edge edge = new Edge(singleJoin, edges.size()); + HashMap, Pair, List>> conjuncts = new HashMap<>(); + for (Expression expression : join.getHashJoinConjuncts()) { Pair ends = findEnds(expression); + if (!conjuncts.containsKey(ends)) { + conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); + } + conjuncts.get(ends).first.add(expression); + } + for (Expression expression : join.getOtherJoinConjuncts()) { + Pair ends = findEnds(expression); + if (!conjuncts.containsKey(ends)) { + conjuncts.put(ends, Pair.of(new ArrayList<>(), new ArrayList<>())); + } + conjuncts.get(ends).second.add(expression); + } + for (Map.Entry, Pair, List>> entry : conjuncts + .entrySet()) { + LogicalJoin singleJoin = new LogicalJoin<>(join.getJoinType(), entry.getValue().first, + entry.getValue().second, JoinHint.NONE, join.left(), join.right()); + Edge edge = new Edge(singleJoin, edges.size()); + Pair ends = entry.getKey(); edge.setLeft(ends.first); + edge.setOriginalLeft(ends.first); edge.setRight(ends.second); + edge.setOriginalRight(ends.second); for (int nodeIndex : LongBitmap.getIterator(edge.getReferenceNodes())) { nodes.get(nodeIndex).attachEdge(edge); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java index 09ff637436..4e38346e6b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/receiver/PlanReceiver.java @@ -105,6 +105,12 @@ public class PlanReceiver implements AbstractReceiver { Preconditions.checkArgument(planTable.containsKey(left)); Preconditions.checkArgument(planTable.containsKey(right)); + // check if the missed edges can be correctly connected by add it to edges + // if not, the plan is invalid because of the missed edges, just return and seek for another valid plan + if (!processMissedEdges(left, right, edges)) { + return true; + } + Memo memo = jobContext.getCascadesContext().getMemo(); emitCount += 1; if (emitCount > limit) { @@ -151,7 +157,7 @@ public class PlanReceiver implements AbstractReceiver { usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap); for (Edge edge : hyperGraph.getEdges()) { if (!usedEdgesBitmap.get(edge.getIndex())) { - outputSlots.addAll(edge.getExpression().getInputSlots()); + outputSlots.addAll(edge.getInputSlots()); } } hyperGraph.getComplexProject() @@ -162,6 +168,47 @@ public class PlanReceiver implements AbstractReceiver { return outputSlots; } + // check if the missed edges can be used to connect left and right together with edges + // return true if no missed edge or the missed edge can be used to connect left and right + // the returned edges includes missed edges if there is any. + private boolean processMissedEdges(long left, long right, List edges) { + boolean canAddMisssedEdges = true; + + // find all reference nodes assume left and right sub graph is connected + BitSet usedEdgesBitmap = new BitSet(); + usedEdgesBitmap.or(usdEdges.get(left)); + usedEdgesBitmap.or(usdEdges.get(right)); + edges.stream().forEach(edge -> usedEdgesBitmap.set(edge.getIndex())); + long allReferenceNodes = getAllReferenceNodes(usedEdgesBitmap); + + // check all edges + // the edge is a missed edge if the edge is not used and its reference nodes is a subset of allReferenceNodes + for (Edge edge : hyperGraph.getEdges()) { + if (LongBitmap.isSubset(edge.getReferenceNodes(), allReferenceNodes) && !usedEdgesBitmap.get( + edge.getIndex())) { + // check the missed edge can be used to connect left and right together with edges + // if the missed edge meet the 2 conditions, it is a valid edge + // 1. the edge's left child's referenced nodes is subset of the left + // 2. the edge's original right node is subset of right + canAddMisssedEdges = canAddMisssedEdges && LongBitmap.isSubset(edge.getLeft(), + left) && LongBitmap.isSubset(edge.getOriginalRight(), right); + + // always add the missed edge to edges + // because the caller will return immediately if canAddMisssedEdges is false + edges.add(edge); + } + } + return canAddMisssedEdges; + } + + private long getAllReferenceNodes(BitSet edgesBitmap) { + long nodes = LongBitmap.newBitmap(); + for (int i = edgesBitmap.nextSetBit(0); i >= 0; i = edgesBitmap.nextSetBit(i + 1)) { + nodes = LongBitmap.or(nodes, hyperGraph.getEdge(i).getReferenceNodes()); + } + return nodes; + } + private void proposeAllDistributedPlans(GroupExpression groupExpression) { jobContext.getCascadesContext().pushJob(new CostAndEnforcerJob(groupExpression, new JobContext(jobContext.getCascadesContext(), PhysicalProperties.ANY, Double.MAX_VALUE))); @@ -200,11 +247,12 @@ public class PlanReceiver implements AbstractReceiver { for (Edge edge : edges) { Preconditions.checkArgument(joinType == null || joinType == edge.getJoinType()); joinType = edge.getJoinType(); - Expression expression = edge.getExpression(); - if (expression instanceof EqualTo) { - hashConjuncts.add(edge.getExpression()); - } else { - otherConjuncts.add(expression); + for (Expression expression : edge.getExpressions()) { + if (expression instanceof EqualTo) { + hashConjuncts.add(expression); + } else { + otherConjuncts.add(expression); + } } } return joinType; @@ -231,6 +279,8 @@ public class PlanReceiver implements AbstractReceiver { @Override public void reset() { planTable.clear(); + projectsOnSubgraph.clear(); + usdEdges.clear(); emitCount = 0; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java index 65814b7252..fc079c2887 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/HyperGraphTest.java @@ -85,10 +85,10 @@ public class HyperGraphTest { + " LOGICAL_OLAP_SCAN3 [label=\"LOGICAL_OLAP_SCAN3 \n" + " rowCount=40.00\"];\n" + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN1 [label=\"1.00\",arrowhead=none]\n" - + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n" + "LOGICAL_OLAP_SCAN1 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n" - + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n" + + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN2 [label=\"1.00\",arrowhead=none]\n" + "LOGICAL_OLAP_SCAN2 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n" + + "LOGICAL_OLAP_SCAN0 -> LOGICAL_OLAP_SCAN3 [label=\"1.00\",arrowhead=none]\n" + "}\n"; Assertions.assertEquals(dottyGraph, target); } diff --git a/regression-test/suites/nereids_syntax_p0/join_reorder_dphyper.groovy b/regression-test/suites/nereids_syntax_p0/join_reorder_dphyper.groovy new file mode 100644 index 0000000000..3146ac8945 --- /dev/null +++ b/regression-test/suites/nereids_syntax_p0/join_reorder_dphyper.groovy @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("join_order_dphyper") { + sql 'set enable_nereids_planner=true' + sql 'set enable_fallback_to_original_planner=false' + + sql """ drop table if exists dphyper_store_sales;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_store_sales ( + ss_sold_date_sk bigint, + ss_customer_sk bigint, + ss_cdemo_sk bigint, + ss_hdemo_sk bigint, + ss_addr_sk bigint, + ss_store_sk bigint, + ss_ticket_number bigint + ) + DUPLICATE KEY(ss_sold_date_sk, ss_customer_sk) + DISTRIBUTED BY HASH(ss_customer_sk) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ drop table if exists dphyper_store_returns;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_store_returns ( + sr_ticket_number bigint + ) + DUPLICATE KEY(sr_ticket_number) + DISTRIBUTED BY HASH(sr_ticket_number) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ drop table if exists dphyper_date_dim;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_date_dim ( + d_date_sk bigint + ) + DUPLICATE KEY(d_date_sk) + DISTRIBUTED BY HASH(d_date_sk) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ drop table if exists dphyper_store;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_store ( + s_store_sk bigint + ) + DUPLICATE KEY(s_store_sk) + DISTRIBUTED BY HASH(s_store_sk) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ drop table if exists dphyper_customer;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_customer ( + c_customer_sk bigint, + c_current_cdemo_sk bigint, + c_current_hdemo_sk bigint, + c_current_addr_sk bigint + ) + DUPLICATE KEY(c_customer_sk) + DISTRIBUTED BY HASH(c_customer_sk) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ drop table if exists dphyper_customer_demographics;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_customer_demographics ( + cd_demo_sk bigint, + cd_marital_status char(1) + ) + DUPLICATE KEY(cd_demo_sk) + DISTRIBUTED BY HASH(cd_demo_sk) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ drop table if exists dphyper_household_demographics;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_household_demographics ( + hd_demo_sk bigint + ) + DUPLICATE KEY(hd_demo_sk) + DISTRIBUTED BY HASH(hd_demo_sk) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ drop table if exists dphyper_customer_address;""" + sql """ + CREATE TABLE IF NOT EXISTS dphyper_customer_address ( + ca_address_sk bigint + ) + DUPLICATE KEY(ca_address_sk) + DISTRIBUTED BY HASH(ca_address_sk) BUCKETS 3 + PROPERTIES ( + "replication_num" = "1" + ) + """ + + explain { + sql("""SELECT + count(*) + FROM + dphyper_store_sales + , dphyper_store_returns + , dphyper_date_dim d1 + , dphyper_store + , dphyper_customer + , dphyper_customer_demographics cd1 + , dphyper_customer_demographics cd2 + , dphyper_household_demographics hd1 + , dphyper_household_demographics hd2 + , dphyper_customer_address ad1 + , dphyper_customer_address ad2 + WHERE (ss_store_sk = s_store_sk) + AND (ss_sold_date_sk = d1.d_date_sk) + AND (ss_customer_sk = c_customer_sk) + AND (ss_cdemo_sk = cd1.cd_demo_sk) + AND (ss_hdemo_sk = hd1.hd_demo_sk) + AND (ss_addr_sk = ad1.ca_address_sk) + AND (ss_ticket_number = sr_ticket_number) + AND (c_current_cdemo_sk = cd2.cd_demo_sk) + AND (c_current_hdemo_sk = hd2.hd_demo_sk) + AND (c_current_addr_sk = ad2.ca_address_sk) + AND (cd1.cd_marital_status <> cd2.cd_marital_status);""") + notContains "VNESTED LOOP JOIN" + } +}