diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java
index 412c3c2308..96a8f71510 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.JoinUtils;
@@ -61,7 +62,11 @@ public class FindHashConditionForJoin extends OneRewriteRuleFactory {
.addAll(join.getHashJoinConjuncts())
.addAll(extractedHashJoinConjuncts)
.build();
- return new LogicalJoin<>(join.getJoinType(),
+ JoinType joinType = join.getJoinType();
+ if (joinType == JoinType.CROSS_JOIN && !combinedHashJoinConjuncts.isEmpty()) {
+ joinType = JoinType.INNER_JOIN;
+ }
+ return new LogicalJoin<>(joinType,
combinedHashJoinConjuncts,
remainedNonHashJoinConjuncts,
join.left(), join.right());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java
index 20a9550af3..9d1ecf84ab 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java
@@ -17,34 +17,43 @@
package org.apache.doris.nereids.rules.rewrite.logical;
-import org.apache.doris.common.Pair;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
-import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.PlanType;
+import org.apache.doris.nereids.trees.plans.logical.AbstractLogicalPlan;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
-import org.apache.doris.nereids.util.ExpressionUtils;
-import org.apache.doris.nereids.util.JoinUtils;
-import org.apache.doris.nereids.util.PlanUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
-import java.util.ArrayList;
-import java.util.HashSet;
import java.util.List;
-import java.util.Map;
+import java.util.Objects;
import java.util.Optional;
-import java.util.Set;
import java.util.stream.Collectors;
/**
* A MultiJoin represents a join of N inputs (NAry-Join).
* The regular Join represent strictly binary input (Binary-Join).
+ *
+ * One {@link MultiJoin} just contains one {@link JoinType} of SEMI/ANTI/OUTER Join.
+ *
+ * joinType is FULL OUTER JOIN, children.size() == 2.
+ * leftChild [FULL OUTER JOIN] rightChild.
+ *
+ * joinType is LEFT (OUTER/SEMI/ANTI) JOIN,
+ * children[0, last) [LEFT (OUTER/SEMI/ANTI) JOIN] lastChild.
+ * eg: MJ([LOJ] A, B, C, D) is {A B C} [LOJ] {D}.
+ *
+ * joinType is RIGHT (OUTER/SEMI/ANTI) JOIN,
+ * firstChild [RIGHT (OUTER/SEMI/ANTI) JOIN] children[1, last].
+ * eg: MJ([ROJ] A, B, C, D) is {A} [ROJ] {B C D}.
*/
-public class MultiJoin extends PlanVisitor {
+public class MultiJoin extends AbstractLogicalPlan {
/*
* topJoin
* / \ MultiJoin
@@ -52,154 +61,125 @@ public class MultiJoin extends PlanVisitor {
* / \ A B C
* A B
*/
- public final List joinInputs = new ArrayList<>();
- public final List conjunctsForAllHashJoins = new ArrayList<>();
- private List conjunctsKeepInFilter = new ArrayList<>();
- /**
- * reorderJoinsAccordingToConditions
- *
- * @return join or filter
- */
- public Optional reorderJoinsAccordingToConditions() {
- if (joinInputs.size() >= 2) {
- Plan root = reorderJoinsAccordingToConditions(joinInputs, conjunctsForAllHashJoins);
- return Optional.of(PlanUtils.filterOrSelf(conjunctsKeepInFilter, root));
- }
- return Optional.empty();
+ // Push predicates into it.
+ // But joinFilter shouldn't contain predicate which just contains one predicate like `T.key > 1`.
+ // Because these predicate should be pushdown.
+ private final List joinFilter;
+ // MultiJoin just contains one OUTER/SEMI/ANTI.
+ private final JoinType joinType;
+ // When contains one OUTER/SEMI/ANTI join, keep separately its condition.
+ private final List notInnerJoinConditions;
+
+ public MultiJoin(List inputs, List joinFilter, JoinType joinType,
+ List notInnerJoinConditions) {
+ super(PlanType.LOGICAL_MULTI_JOIN, inputs.toArray(new Plan[0]));
+ this.joinFilter = Objects.requireNonNull(joinFilter);
+ this.joinType = joinType;
+ this.notInnerJoinConditions = Objects.requireNonNull(notInnerJoinConditions);
}
- /**
- * Reorder join orders according to join conditions to eliminate cross join.
- *
- * Let's say we have input join tables: [t1, t2, t3] and
- * conjunctive predicates: [t1.id=t3.id, t2.id=t3.id]
- * The input join for t1 and t2 is cross join.
- *
- * The algorithm split join inputs into two groups: `left input` t1 and `candidate right input` [t2, t3].
- * Try to find an inner join from t1 and candidate right inputs [t2, t3], if any combination
- * of [Join(t1, t2), Join(t1, t3)] could be optimized to inner join according to the join conditions.
- *
- * As a result, Join(t1, t3) is an inner join.
- * Then the logic is applied to the rest of [Join(t1, t3), t2] recursively.
- */
- private Plan reorderJoinsAccordingToConditions(List joinInputs, List conjuncts) {
- if (joinInputs.size() == 2) {
- Pair, List> pair = JoinUtils.extractExpressionForHashTable(
- joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
- joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
- conjuncts);
- List joinConditions = pair.first;
- conjunctsKeepInFilter = pair.second;
-
- return new LogicalJoin<>(JoinType.INNER_JOIN,
- ExpressionUtils.EMPTY_CONDITION,
- joinConditions,
- joinInputs.get(0), joinInputs.get(1));
- }
- // input size >= 3;
- Plan left = joinInputs.get(0);
- List candidate = joinInputs.subList(1, joinInputs.size());
-
- List leftOutput = left.getOutput();
- Optional rightOpt = candidate.stream().filter(right -> {
- List rightOutput = right.getOutput();
-
- Set joinOutput = getJoinOutput(left, right);
- Optional joinCond = conjuncts.stream()
- .filter(expr -> {
- Set exprInputSlots = expr.getInputSlots();
- if (exprInputSlots.isEmpty()) {
- return false;
- }
-
- if (new HashSet<>(leftOutput).containsAll(exprInputSlots)) {
- return false;
- }
-
- if (new HashSet<>(rightOutput).containsAll(exprInputSlots)) {
- return false;
- }
-
- return joinOutput.containsAll(exprInputSlots);
- }).findFirst();
- return joinCond.isPresent();
- }).findFirst();
-
- Plan right = rightOpt.orElseGet(() -> candidate.get(1));
- Pair, List> pair = JoinUtils.extractExpressionForHashTable(
- joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
- joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
- conjuncts);
- List joinConditions = pair.first;
- List nonJoinConditions = pair.second;
- LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, ExpressionUtils.EMPTY_CONDITION,
- joinConditions, left, right);
-
- List newInputs = new ArrayList<>();
- newInputs.add(join);
- newInputs.addAll(candidate.stream().filter(plan -> !right.equals(plan)).collect(Collectors.toList()));
- return reorderJoinsAccordingToConditions(newInputs, nonJoinConditions);
+ public JoinType getJoinType() {
+ return joinType;
}
- private Map> splitConjuncts(List conjuncts, Set slots) {
- return conjuncts.stream().collect(Collectors.partitioningBy(
- // TODO: support non equal to conditions.
- expr -> expr instanceof EqualTo && slots.containsAll(expr.getInputSlots())));
+ public List getJoinFilter() {
+ return joinFilter;
}
- private Set getJoinOutput(Plan left, Plan right) {
- HashSet joinOutput = new HashSet<>();
- joinOutput.addAll(left.getOutput());
- joinOutput.addAll(right.getOutput());
- return joinOutput;
+ public List getNotInnerJoinConditions() {
+ return notInnerJoinConditions;
}
@Override
- public Void visit(Plan plan, Void context) {
- for (Plan child : plan.children()) {
- child.accept(this, context);
- }
- return null;
+ public MultiJoin withChildren(List children) {
+ return new MultiJoin(children, joinFilter, joinType, notInnerJoinConditions);
}
@Override
- public Void visitLogicalFilter(LogicalFilter extends Plan> filter, Void context) {
- Plan child = filter.child();
- if (child instanceof LogicalJoin) {
- conjunctsForAllHashJoins.addAll(ExpressionUtils.extractConjunction(filter.getPredicates()));
+ public List computeOutput() {
+ Builder builder = ImmutableList.builder();
+
+ if (joinType.isInnerOrCrossJoin()) {
+ // INNER/CROSS
+ for (Plan child : children) {
+ builder.addAll(child.getOutput());
+ }
+ return builder.build();
}
- child.accept(this, context);
- return null;
+ // FULL OUTER JOIN
+ if (joinType.isFullOuterJoin()) {
+ for (Plan child : children) {
+ builder.addAll(child.getOutput().stream()
+ .map(o -> o.withNullable(true))
+ .collect(Collectors.toList()));
+ }
+ return builder.build();
+ }
+
+ // RIGHT OUTER | RIGHT_SEMI/ANTI
+ if (joinType.isRightJoin()) {
+ // RIGHT OUTER
+ if (joinType.isRightOuterJoin()) {
+ builder.addAll(children.get(0).getOutput().stream()
+ .map(o -> o.withNullable(true))
+ .collect(Collectors.toList()));
+ }
+ for (int i = 1; i < children.size(); i++) {
+ builder.addAll(children.get(i).getOutput());
+ }
+
+ return builder.build();
+ }
+
+ // LEFT OUTER | LEFT_SEMI/ANTI
+ if (joinType.isLeftJoin()) {
+ for (int i = 0; i < children.size() - 1; i++) {
+ builder.addAll(children.get(i).getOutput());
+ }
+ // LEFT OUTER
+ if (joinType.isLeftOuterJoin()) {
+ builder.addAll(children.get(arity() - 1).getOutput().stream()
+ .map(o -> o.withNullable(true))
+ .collect(Collectors.toList()));
+ }
+
+ return builder.build();
+ }
+
+ throw new RuntimeException("unreachable");
}
@Override
- public Void visitLogicalJoin(LogicalJoin extends Plan, ? extends Plan> join, Void context) {
- if (join.getJoinType() != JoinType.CROSS_JOIN && join.getJoinType() != JoinType.INNER_JOIN) {
- return null;
- }
+ public R accept(PlanVisitor visitor, C context) {
+ throw new RuntimeException("multiJoin can't invoke accept");
+ }
- join.left().accept(this, context);
- join.right().accept(this, context);
+ @Override
+ public List extends Expression> getExpressions() {
+ return new Builder()
+ .addAll(joinFilter)
+ .addAll(notInnerJoinConditions)
+ .build();
+ }
- conjunctsForAllHashJoins.addAll(join.getHashJoinConjuncts());
- conjunctsForAllHashJoins.addAll(join.getOtherJoinConjuncts());
+ @Override
+ public Plan withGroupExpression(Optional groupExpression) {
+ throw new RuntimeException("multiJoin can't invoke withGroupExpression");
+ }
- Plan leftChild = join.left();
- if (join.left() instanceof LogicalFilter) {
- leftChild = join.left().child(0);
- }
- if (leftChild instanceof GroupPlan) {
- joinInputs.add(join.left());
- }
- Plan rightChild = join.right();
- if (join.right() instanceof LogicalFilter) {
- rightChild = join.right().child(0);
- }
- if (rightChild instanceof GroupPlan) {
- joinInputs.add(join.right());
- }
- return null;
+ @Override
+ public Plan withLogicalProperties(Optional logicalProperties) {
+ throw new RuntimeException("multiJoin can't invoke withLogicalProperties");
+ }
+
+ @Override
+ public String toString() {
+ return Utils.toSqlString("MultiJoin",
+ "joinType", joinType,
+ "joinFilter", joinFilter,
+ "notInnerJoinConditions", notInnerJoinConditions
+ );
}
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
index 88ad3eff60..1cbdc370e2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java
@@ -17,12 +17,31 @@
package org.apache.doris.nereids.rules.rewrite.logical;
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.JoinUtils;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableList.Builder;
+import com.google.common.collect.Lists;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
* Try to eliminate cross join via finding join conditions in filters and change the join orders.
@@ -37,21 +56,308 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
* SELECT * FROM t1 JOIN t3 ON t1.id=t3.id JOIN t2 ON t2.id=t3.id
*
*
- * TODO: This is tested by SSB queries currently, add more `unit` test for this rule
- * when we have a plan building and comparing framework.
+ * Using the {@link MultiJoin} to complete this task.
+ * {Join cluster}: contain multiple join with filter inside.
+ *
+ * - {Join cluster} to MultiJoin
+ * - MultiJoin to {Join cluster}
+ *
*/
public class ReorderJoin extends OneRewriteRuleFactory {
@Override
public Rule build() {
return logicalFilter(subTree(LogicalJoin.class, LogicalFilter.class)).thenApply(ctx -> {
LogicalFilter filter = ctx.root;
- if (!ctx.cascadesContext.getConnectContext().getSessionVariable()
- .isEnableNereidsReorderToEliminateCrossJoin()) {
- return filter;
- }
- MultiJoin multiJoin = new MultiJoin();
- filter.accept(multiJoin, null);
- return multiJoin.reorderJoinsAccordingToConditions().orElse(filter);
+
+ Plan plan = joinToMultiJoin(filter);
+ Preconditions.checkState(plan instanceof MultiJoin);
+ MultiJoin multiJoin = (MultiJoin) plan;
+ Plan after = multiJoinToJoin(multiJoin);
+ return after;
}).toRule(RuleType.REORDER_JOIN);
}
+
+ /**
+ * Recursively convert to
+ * {@link LogicalJoin} or
+ * {@link LogicalFilter}--{@link LogicalJoin}
+ * --> {@link MultiJoin}
+ */
+ public Plan joinToMultiJoin(Plan plan) {
+ // subtree can't specify the end of Pattern. so end can be GroupPlan or Filter
+ if (plan instanceof GroupPlan
+ || (plan instanceof LogicalFilter && plan.child(0) instanceof GroupPlan)) {
+ return plan;
+ }
+
+ List inputs = Lists.newArrayList();
+ List joinFilter = Lists.newArrayList();
+ List notInnerJoinConditions = Lists.newArrayList();
+
+ LogicalJoin, ?> join;
+ // Implicit rely on {rule: MergeFilters}, so don't exist filter--filter--join.
+ if (plan instanceof LogicalFilter) {
+ LogicalFilter> filter = (LogicalFilter>) plan;
+ joinFilter.addAll(ExpressionUtils.extractConjunction(filter.getPredicates()));
+ join = (LogicalJoin, ?>) filter.child();
+ } else {
+ join = (LogicalJoin, ?>) plan;
+ }
+
+ if (join.getJoinType().isInnerOrCrossJoin()) {
+ joinFilter.addAll(join.getHashJoinConjuncts());
+ joinFilter.addAll(join.getOtherJoinConjuncts());
+ } else {
+ notInnerJoinConditions.addAll(join.getHashJoinConjuncts());
+ notInnerJoinConditions.addAll(join.getOtherJoinConjuncts());
+ }
+
+ // recursively convert children.
+ Plan left = joinToMultiJoin(join.left());
+ Plan right = joinToMultiJoin(join.right());
+
+ boolean changeLeft = join.getJoinType().isRightJoin()
+ || join.getJoinType().isFullOuterJoin();
+ if (canCombine(left, changeLeft)) {
+ MultiJoin leftMultiJoin = (MultiJoin) left;
+ inputs.addAll(leftMultiJoin.children());
+ joinFilter.addAll(leftMultiJoin.getJoinFilter());
+ } else {
+ inputs.add(left);
+ }
+
+ boolean changeRight = join.getJoinType().isLeftJoin()
+ || join.getJoinType().isFullOuterJoin();
+ if (canCombine(right, changeRight)) {
+ MultiJoin rightMultiJoin = (MultiJoin) right;
+ inputs.addAll(rightMultiJoin.children());
+ joinFilter.addAll(rightMultiJoin.getJoinFilter());
+ } else {
+ inputs.add(right);
+ }
+
+ return new MultiJoin(
+ inputs,
+ joinFilter,
+ join.getJoinType(),
+ notInnerJoinConditions);
+ }
+
+ /**
+ * Recursively convert to
+ * {@link MultiJoin}
+ * -->
+ * {@link LogicalJoin} or
+ * {@link LogicalFilter}--{@link LogicalJoin}
+ *
+ * When all input is CROSS/Inner Join, all join will be flattened.
+ * Otherwise, we will split {join cluster} into multiple {@link MultiJoin}.
+ *
+ * Here are examples of the {@link MultiJoin}s constructed after this rules has been applied.
+ *
+ * simple example:
+ *
+ * - A JOIN B --> MJ(A, B)
+ *
- A JOIN B JOIN C JOIN D --> MJ(A, B, C, D)
+ *
- A LEFT (OUTER/SEMI/ANTI) JOIN B --> MJ([LOJ/LSJ/LAJ]A, B)
+ *
- A LEFT (OUTER/SEMI/ANTI) JOIN B --> MJ([ROJ/RSJ/RAJ]A, B)
+ *
- A FULL JOIN B --> MJ[FOJ](A, B)
+ *
+ *
+ *
+ * complex example:
+ *
+ * - A LEFT OUTER JOIN (B JOIN C) --> MJ([LOJ]A, MJ(B, C)))
+ *
- (A JOIN B) LEFT JOIN C --> MJ(A, B, C)
+ *
- (A LEFT OUTER JOIN B) JOIN C --> MJ(MJ(A, B), C)
+ *
- A LEFT JOIN (B FULL JOIN C) --> MJ(A, MJ[full](B, C))
+ *
- (A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) --> MJ[full](MJ(A, B), MJ(C, D))
+ *
+ *
+ * more complex example:
+ *
+ * - A JOIN B JOIN C LEFT JOIN D --> MJ([LOJ]A, B, C, D)
+ *
- A JOIN B JOIN C LEFT JOIN (D JOIN F) --> MJ([LOJ]A, B, C, MJ(D, F))
+ *
- A RIGHT JOIN (B JOIN C JOIN D)--> MJ([ROJ]A, B, C, D)
+ *
- A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
+ *
+ *
+ *
+ * Graphic presentation:
+ * A JOIN B JOIN C LEFT JOIN D JOIN F
+ * left left│
+ * A B C D F ──► A B C │ D F ──► MJ(LOJ A,B,C,MJ(DF)
+ *
+ * A JOIN B RIGHT JOIN C JOIN D JOIN F
+ * right │right
+ * A B C D F ──► A B │ C D F ──► MJ(A,B,MJ(ROJ C,D,F)
+ *
+ * (A JOIN B JOIN C) FULL JOIN (D JOIN F)
+ * full │
+ * A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F))
+ *
+ */
+ public Plan multiJoinToJoin(MultiJoin multiJoin) {
+ if (multiJoin.arity() == 1) {
+ return PlanUtils.filterOrSelf(multiJoin.getJoinFilter(), multiJoin.child(0));
+ }
+
+ Builder builder = ImmutableList.builder();
+ // recursively hanlde multiJoin children.
+ for (Plan child : multiJoin.children()) {
+ if (child instanceof MultiJoin) {
+ MultiJoin childMultiJoin = (MultiJoin) child;
+ builder.add(multiJoinToJoin(childMultiJoin));
+ } else {
+ builder.add(child);
+ }
+ }
+ MultiJoin multiJoinHandleChildren = multiJoin.withChildren(builder.build());
+
+ if (!multiJoinHandleChildren.getJoinType().isInnerOrCrossJoin()) {
+ List leftFilter = Lists.newArrayList();
+ List rightFilter = Lists.newArrayList();
+ List remainingFilter = Lists.newArrayList();
+ Plan left;
+ Plan right;
+ if (multiJoinHandleChildren.getJoinType().isLeftJoin()) {
+ left = multiJoinToJoin(new MultiJoin(
+ multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
+ leftFilter,
+ JoinType.INNER_JOIN,
+ ExpressionUtils.EMPTY_CONDITION));
+ right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
+ } else if (multiJoinHandleChildren.getJoinType().isRightJoin()) {
+ left = multiJoinHandleChildren.child(0);
+ right = multiJoinToJoin(new MultiJoin(
+ multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
+ rightFilter,
+ JoinType.INNER_JOIN,
+ ExpressionUtils.EMPTY_CONDITION));
+ } else {
+ left = multiJoinToJoin(new MultiJoin(
+ multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
+ leftFilter,
+ JoinType.INNER_JOIN,
+ ExpressionUtils.EMPTY_CONDITION));
+ right = multiJoinToJoin(new MultiJoin(
+ multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
+ rightFilter,
+ JoinType.INNER_JOIN,
+ ExpressionUtils.EMPTY_CONDITION));
+ }
+
+ // split filter
+ for (Expression expr : multiJoinHandleChildren.getJoinFilter()) {
+ Set exprInputSlots = expr.getInputSlots();
+ Preconditions.checkState(!exprInputSlots.isEmpty());
+
+ if (left.getOutputSet().containsAll(exprInputSlots)) {
+ leftFilter.add(expr);
+ } else if (right.getOutputSet().containsAll(exprInputSlots)) {
+ rightFilter.add(expr);
+ } else if (multiJoin.getOutputSet().containsAll(exprInputSlots)) {
+ remainingFilter.add(expr);
+ } else {
+ NereidsPlanner.LOG.error("invalid expression, exist slot that multiJoin don't contains.");
+ throw new RuntimeException("invalid expression, exist slot that multiJoin don't contains.");
+ }
+ }
+
+ return PlanUtils.filterOrSelf(remainingFilter, new LogicalJoin<>(
+ multiJoinHandleChildren.getJoinType(),
+ ExpressionUtils.EMPTY_CONDITION,
+ multiJoinHandleChildren.getNotInnerJoinConditions(),
+ PlanUtils.filterOrSelf(leftFilter, left), PlanUtils.filterOrSelf(rightFilter, right)));
+ }
+
+ // following this multiJoin just contain INNER/CROSS.
+ List joinFilter = multiJoinHandleChildren.getJoinFilter();
+
+ Plan left = multiJoinHandleChildren.child(0);
+ List candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity());
+
+ LogicalJoin extends Plan, ? extends Plan> join = findInnerJoin(left, candidates, joinFilter);
+ List newInputs = Lists.newArrayList();
+ newInputs.add(join);
+ newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList()));
+
+ joinFilter.removeAll(join.getHashJoinConjuncts());
+ joinFilter.removeAll(join.getOtherJoinConjuncts());
+ // TODO(wj): eliminate this recursion.
+ return multiJoinToJoin(new MultiJoin(
+ newInputs,
+ joinFilter,
+ JoinType.INNER_JOIN,
+ ExpressionUtils.EMPTY_CONDITION));
+ }
+
+ /**
+ * Returns whether an input can be merged without changing semantics.
+ *
+ * @param input input into a MultiJoin or (GroupPlan|LogicalFilter)
+ * @param changeChildren generate nullable or one side child not exist.
+ * @return true if the input can be combined into a parent MultiJoin
+ */
+ private static boolean canCombine(Plan input, boolean changeChildren) {
+ return input instanceof MultiJoin
+ && ((MultiJoin) input).getJoinType().isInnerOrCrossJoin()
+ && !changeChildren;
+ }
+
+ private Set getJoinOutput(Plan left, Plan right) {
+ HashSet joinOutput = new HashSet<>();
+ joinOutput.addAll(left.getOutput());
+ joinOutput.addAll(right.getOutput());
+ return joinOutput;
+ }
+
+ /**
+ * Find hash condition from joinFilter
+ * Get InnerJoin from left, right from [candidates].
+ *
+ * @return InnerJoin or CrossJoin{left, last of [candidates]}
+ */
+ private LogicalJoin extends Plan, ? extends Plan> findInnerJoin(Plan left, List candidates,
+ List joinFilter) {
+ Set leftOutputSet = left.getOutputSet();
+ for (int i = 0; i < candidates.size(); i++) {
+ Plan candidate = candidates.get(i);
+ Set rightOutputSet = candidate.getOutputSet();
+
+ Set joinOutput = getJoinOutput(left, candidate);
+
+ List currentJoinFilter = joinFilter.stream()
+ .filter(expr -> {
+ Set exprInputSlots = expr.getInputSlots();
+ Preconditions.checkState(exprInputSlots.size() > 1,
+ "Predicate like table.col > 1 must have pushdown.");
+ if (leftOutputSet.containsAll(exprInputSlots)) {
+ return false;
+ }
+ if (rightOutputSet.containsAll(exprInputSlots)) {
+ return false;
+ }
+
+ return joinOutput.containsAll(exprInputSlots);
+ }).collect(Collectors.toList());
+
+ Pair, List> pair = JoinUtils.extractExpressionForHashTable(
+ left.getOutput(), candidate.getOutput(), currentJoinFilter);
+ List hashJoinConditions = pair.first;
+ List otherJoinConditions = pair.second;
+ if (!hashJoinConditions.isEmpty()) {
+ return new LogicalJoin<>(JoinType.INNER_JOIN,
+ hashJoinConditions, otherJoinConditions,
+ left, candidate);
+ }
+
+ if (i == candidates.size() - 1) {
+ return new LogicalJoin<>(JoinType.CROSS_JOIN,
+ hashJoinConditions, otherJoinConditions,
+ left, candidate);
+ }
+ }
+ throw new RuntimeException("findInnerJoin: can't reach here");
+ }
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java
index 764eecd154..b9e3d909c1 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java
@@ -92,6 +92,14 @@ public enum JoinType {
return this == INNER_JOIN;
}
+ public final boolean isInnerOrCrossJoin() {
+ return this == INNER_JOIN || this == CROSS_JOIN;
+ }
+
+ public final boolean isLeftJoin() {
+ return this == LEFT_OUTER_JOIN || this == LEFT_ANTI_JOIN || this == LEFT_SEMI_JOIN;
+ }
+
public final boolean isRightJoin() {
return this == RIGHT_OUTER_JOIN || this == RIGHT_ANTI_JOIN || this == RIGHT_SEMI_JOIN;
}
@@ -108,6 +116,14 @@ public enum JoinType {
return this == RIGHT_OUTER_JOIN;
}
+ public final boolean isLeftSemiOrAntiJoin() {
+ return this == LEFT_SEMI_JOIN || this == LEFT_ANTI_JOIN;
+ }
+
+ public final boolean isRightSemiOrAntiJoin() {
+ return this == RIGHT_SEMI_JOIN || this == RIGHT_ANTI_JOIN;
+ }
+
public final boolean isSemiOrAntiJoin() {
return this == LEFT_SEMI_JOIN || this == RIGHT_SEMI_JOIN || this == LEFT_ANTI_JOIN || this == RIGHT_ANTI_JOIN;
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java
index 9ef8a7eae9..5fb4564b32 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java
@@ -42,6 +42,7 @@ public enum PlanType {
LOGICAL_SELECT_HINT,
LOGICAL_ASSERT_NUM_ROWS,
LOGICAL_HAVING,
+ LOGICAL_MULTI_JOIN,
GROUP_PLAN,
// physical plan
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java
index 9e5f624fec..0029c205f7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java
@@ -102,7 +102,7 @@ public class LogicalFilter extends LogicalUnary R accept(PlanVisitor visitor, C context) {
- return visitor.visitLogicalFilter((LogicalFilter) this, context);
+ return visitor.visitLogicalFilter(this, context);
}
@Override
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
index 7aa28fe259..4a2bcb1cb8 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java
@@ -17,26 +17,15 @@
package org.apache.doris.nereids.datasets.ssb;
-import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.plans.Plan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
-import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
-import org.apache.doris.nereids.util.PlanRewriter;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
-import java.util.ArrayList;
import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-public class SSBJoinReorderTest extends SSBTestBase {
+public class SSBJoinReorderTest extends SSBTestBase implements PatternMatchSupported {
@Test
public void q4_1() {
test(
@@ -47,9 +36,11 @@ public class SSBJoinReorderTest extends SSBTestBase {
"(lo_suppkey = s_suppkey)",
"(lo_partkey = p_partkey)"
),
- ImmutableList.of("(((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
- + "= CAST('AMERICA' AS STRING))) AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) "
- + "OR (CAST(p_mfgr AS STRING) = CAST('MFGR#2' AS STRING))))")
+ ImmutableList.of(
+ "(CAST(c_region AS STRING) = 'AMERICA')",
+ "(CAST(s_region AS STRING) = 'AMERICA')",
+ "((CAST(p_mfgr AS STRING) = 'MFGR#1') OR (CAST(p_mfgr AS STRING) = 'MFGR#2'))"
+ )
);
}
@@ -64,10 +55,11 @@ public class SSBJoinReorderTest extends SSBTestBase {
"(lo_partkey = p_partkey)"
),
ImmutableList.of(
- "((((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
- + "= CAST('AMERICA' AS STRING))) AND ((d_year = 1997) OR (d_year = 1998))) "
- + "AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) OR (CAST(p_mfgr AS STRING) "
- + "= CAST('MFGR#2' AS STRING))))")
+ "((d_year = 1997) OR (d_year = 1998))",
+ "(CAST(c_region AS STRING) = 'AMERICA')",
+ "(CAST(s_region AS STRING) = 'AMERICA')",
+ "((CAST(p_mfgr AS STRING) = 'MFGR#1') OR (CAST(p_mfgr AS STRING) = 'MFGR#2'))"
+ )
);
}
@@ -81,94 +73,31 @@ public class SSBJoinReorderTest extends SSBTestBase {
"(lo_suppkey = s_suppkey)",
"(lo_partkey = p_partkey)"
),
- ImmutableList.of("(((CAST(s_nation AS STRING) = CAST('UNITED STATES' AS STRING)) AND ((d_year = 1997) "
- + "OR (d_year = 1998))) AND (CAST(p_category AS STRING) = CAST('MFGR#14' AS STRING)))")
+ ImmutableList.of(
+ "((d_year = 1997) OR (d_year = 1998))",
+ "(CAST(s_nation AS STRING) = 'UNITED STATES')",
+ "(CAST(p_category AS STRING) = 'MFGR#14')"
+ )
);
}
private void test(String sql, List expectJoinConditions, List expectFilterPredicates) {
- LogicalPlan analyzed = analyze(sql);
- LogicalPlan plan = testJoinReorder(analyzed);
- System.out.println(plan.treeString());
- new PlanChecker(expectJoinConditions, expectFilterPredicates).check(plan);
- }
+ PlanChecker planChecker = PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .printlnTree();
- private LogicalPlan testJoinReorder(LogicalPlan plan) {
- return (LogicalPlan) PlanRewriter.topDownRewrite(plan, connectContext, new ReorderJoin());
- }
-
- private static class PlanChecker extends PlanVisitor {
- private final List joinInputs = new ArrayList<>();
- private final List joins = new ArrayList<>();
- private final List filters = new ArrayList<>();
- // TODO: it's tricky to compare expression by string, use a graceful manner to do this in the future.
- private final List expectJoinConditions;
- private final List expectFilterPredicates;
-
- public PlanChecker(List expectJoinConditions, List expectFilterPredicates) {
- this.expectJoinConditions = expectJoinConditions;
- this.expectFilterPredicates = expectFilterPredicates;
+ for (String expectJoinCondition : expectJoinConditions) {
+ planChecker.matches(
+ innerLogicalJoin().when(
+ join -> join.getHashJoinConjuncts().get(0).toSql().equals(expectJoinCondition))
+ );
}
- public void check(Plan plan) {
- plan.accept(this, new Context(null));
-
- // check join table orders
- Assertions.assertEquals(
- ImmutableList.of("dates", "lineorder", "customer", "supplier", "part"),
- joinInputs.stream().map(p -> p.getTable().getName()).collect(Collectors.toList()));
-
- // check join conditions
- List actualJoinConditions = joins.stream().map(j -> {
- Optional condition = j.getOnClauseCondition();
- return condition.map(Expression::toSql).orElse("");
- }).collect(Collectors.toList());
-
- Assertions.assertEquals(expectJoinConditions, actualJoinConditions);
-
- // check filter predicates
- List actualFilterPredicates = filters.stream()
- .map(f -> f.getPredicates().toSql()).collect(Collectors.toList());
- Assertions.assertEquals(expectFilterPredicates, actualFilterPredicates);
- }
-
- @Override
- public Void visit(Plan plan, Context context) {
- for (Plan child : plan.children()) {
- child.accept(this, new Context(plan));
- }
- return null;
- }
-
- @Override
- public Void visitLogicalRelation(LogicalRelation relation, Context context) {
- if (context.parent instanceof LogicalJoin) {
- joinInputs.add(relation);
- }
- return null;
- }
-
- @Override
- public Void visitLogicalFilter(LogicalFilter extends Plan> filter, Context context) {
- filters.add(filter);
- filter.child().accept(this, new Context(filter));
- return null;
- }
-
- @Override
- public Void visitLogicalJoin(LogicalJoin extends Plan, ? extends Plan> join, Context context) {
- join.left().accept(this, new Context(join));
- join.right().accept(this, new Context(join));
- joins.add(join);
- return null;
- }
- }
-
- private static class Context {
- public final Plan parent;
-
- public Context(Plan parent) {
- this.parent = parent;
+ for (String expectFilterPredicate : expectFilterPredicates) {
+ planChecker.matches(
+ logicalFilter().when(filter -> filter.getPredicates().toSql().equals(expectFilterPredicate))
+ );
}
}
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/JoinReorderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/JoinReorderTest.java
deleted file mode 100644
index e0806564ef..0000000000
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/JoinReorderTest.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.rewrite.logical;
-
-import org.apache.doris.nereids.analyzer.UnboundRelation;
-import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
-import org.apache.doris.nereids.trees.plans.JoinType;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
-import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
-import org.apache.doris.nereids.util.MemoTestUtils;
-import org.apache.doris.nereids.util.PatternMatchSupported;
-import org.apache.doris.nereids.util.PlanChecker;
-
-import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Test;
-
-public class JoinReorderTest implements PatternMatchSupported {
-
- /**
- * To test do not throw unexpected exception when join type is not inner or cross.
- */
- @Test
- public void testWithOuterJoin() {
- UnboundRelation relation1 = new UnboundRelation(Lists.newArrayList("db", "table1"));
- UnboundRelation relation2 = new UnboundRelation(Lists.newArrayList("db", "table2"));
- LogicalJoin outerJoin = new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN, relation1, relation2);
- LogicalFilter logicalFilter = new LogicalFilter<>(BooleanLiteral.FALSE, outerJoin);
-
- PlanChecker.from(MemoTestUtils.createConnectContext(), logicalFilter)
- .applyBottomUp(new ReorderJoin())
- .matches(logicalFilter(leftOuterLogicalJoin(unboundRelation(), unboundRelation())));
- }
-}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
new file mode 100644
index 0000000000..ffb7e16510
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java
@@ -0,0 +1,183 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+class ReorderJoinTest implements PatternMatchSupported {
+
+ private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+ private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
+ private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
+ private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0);
+
+ @Test
+ public void testLeftOuterJoin() {
+ ImmutableList plans = ImmutableList.of(
+ new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build(),
+ new LogicalPlanBuilder(scan1)
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build()
+ );
+
+ check(plans);
+ }
+
+ @Test
+ public void testRightOuterJoin() {
+ ImmutableList plans = ImmutableList.of(
+ new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build(),
+ new LogicalPlanBuilder(scan1)
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build()
+ );
+
+ check(plans);
+ }
+
+ @Test
+ public void testLeftSemiJoin() {
+ ImmutableList plans = ImmutableList.of(
+ new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build(),
+ new LogicalPlanBuilder(scan1)
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .hashJoinUsing(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build()
+ );
+
+ check(plans);
+ }
+
+ @Test
+ public void testRightSemiJoin() {
+ ImmutableList plans = ImmutableList.of(
+ new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0)))
+ .build(),
+ new LogicalPlanBuilder(scan1)
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build()
+ );
+
+ check(plans);
+ }
+
+ @Test
+ public void testFullOuterJoin() {
+ ImmutableList plans = ImmutableList.of(
+ new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(0, 0))
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build(),
+ new LogicalPlanBuilder(scan1)
+ .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
+ .hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(0, 0))
+ .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
+ .build()
+ );
+
+ check(plans);
+ }
+
+ public void check(List plans) {
+ for (LogicalPlan plan : plans) {
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyBottomUp(new ReorderJoin())
+ .matchesFromRoot(
+ logicalJoin(
+ logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
+ leafPlan()
+ ).whenNot(join -> join.getJoinType().isCrossJoin())
+ )
+ .printlnTree();
+ }
+ }
+
+ /**
+ * join
+ * crossjoin / \
+ * / \ join D
+ * innerjoin innerjoin ──► / \
+ * / \ / \ join C
+ * A B C D / \
+ * A B
+ */
+ @Test
+ public void testInnerOrCrossJoin() {
+ LogicalPlan leftJoin = new LogicalPlanBuilder(scan1)
+ .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .build();
+ LogicalPlan rightJoin = new LogicalPlanBuilder(scan3)
+ .hashJoinUsing(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
+ .build();
+
+ LogicalPlan plan = new LogicalPlanBuilder(leftJoin)
+ .hashJoinEmptyOn(rightJoin, JoinType.CROSS_JOIN)
+ .filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0)))
+ .build();
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
+ .applyBottomUp(new ReorderJoin())
+ .matchesFromRoot(
+ logicalJoin(
+ logicalJoin(
+ logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
+ leafPlan()
+ ).whenNot(join -> join.getJoinType().isCrossJoin()),
+ leafPlan()
+ ).whenNot(join -> join.getJoinType().isCrossJoin())
+ )
+ .printlnTree();
+ }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
similarity index 59%
rename from fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java
rename to fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
index 2d958a613b..4a52796734 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java
@@ -23,9 +23,12 @@ import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
+import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
-public class SqlTest extends TestWithFeService implements PatternMatchSupported {
+import java.util.List;
+
+public class MultiJoinTest extends TestWithFeService implements PatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
@@ -68,17 +71,44 @@ public class SqlTest extends TestWithFeService implements PatternMatchSupported
}
@Test
- void testSql() {
- // String sql = "SELECT *"
- // + " FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id"
- // + " WHERE T1.id = T2.id";
- String sql = "SELECT *"
- + " FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1"
- + " WHERE T1.id = T2.id";
- PlanChecker.from(connectContext)
- .analyze(sql)
- .applyTopDown(new ReorderJoin())
- .implement()
- .printlnTree();
+ void testMultiJoinEliminateCross() {
+ List sqls = ImmutableList.builder()
+ .add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id")
+ .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id")
+ .build();
+
+ for (String sql : sqls) {
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .applyBottomUp(new ReorderJoin())
+ .matches(
+ logicalJoin(
+ logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
+ leafPlan()
+ ).whenNot(join -> join.getJoinType().isCrossJoin())
+ )
+ .printlnTree();
+ }
+ }
+
+ @Test
+ void testMultiJoinExistCross() {
+ List sqls = ImmutableList.builder()
+ .add("SELECT * FROM T2 LEFT SEMI JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id > T2.id")
+ .build();
+
+ for (String sql : sqls) {
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .applyBottomUp(new ReorderJoin())
+ .matches(
+ logicalJoin(
+ logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
+ leafPlan()
+ ).when(join -> join.getJoinType().isCrossJoin())
+ .whenNot(join -> join.getOtherJoinConjuncts().isEmpty())
+ )
+ .printlnTree();
+ }
}
}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java
index 94f5149bd3..219acbe669 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java
@@ -81,7 +81,6 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
= "select * from t6 where t6.k1 < (select max(aa) from (select v1 as aa from t7 where t6.k2=t7.v2) t2 )";
private final List testSql = ImmutableList.of(
sql1, sql2, sql3, sql4, sql5, sql6, sql7, sql8, sql9, sql10
-
);
@Override