diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java index 412c3c2308..96a8f71510 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/FindHashConditionForJoin.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.JoinUtils; @@ -61,7 +62,11 @@ public class FindHashConditionForJoin extends OneRewriteRuleFactory { .addAll(join.getHashJoinConjuncts()) .addAll(extractedHashJoinConjuncts) .build(); - return new LogicalJoin<>(join.getJoinType(), + JoinType joinType = join.getJoinType(); + if (joinType == JoinType.CROSS_JOIN && !combinedHashJoinConjuncts.isEmpty()) { + joinType = JoinType.INNER_JOIN; + } + return new LogicalJoin<>(joinType, combinedHashJoinConjuncts, remainedNonHashJoinConjuncts, join.left(), join.right()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java index 20a9550af3..9d1ecf84ab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java @@ -17,34 +17,43 @@ package org.apache.doris.nereids.rules.rewrite.logical; -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.PlanType; +import org.apache.doris.nereids.trees.plans.logical.AbstractLogicalPlan; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; -import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.JoinUtils; -import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.Utils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; -import java.util.ArrayList; -import java.util.HashSet; import java.util.List; -import java.util.Map; +import java.util.Objects; import java.util.Optional; -import java.util.Set; import java.util.stream.Collectors; /** * A MultiJoin represents a join of N inputs (NAry-Join). * The regular Join represent strictly binary input (Binary-Join). + *

+ * One {@link MultiJoin} just contains one {@link JoinType} of SEMI/ANTI/OUTER Join. + *

+ * joinType is FULL OUTER JOIN, children.size() == 2. + * leftChild [FULL OUTER JOIN] rightChild. + *

+ * joinType is LEFT (OUTER/SEMI/ANTI) JOIN, + * children[0, last) [LEFT (OUTER/SEMI/ANTI) JOIN] lastChild. + * eg: MJ([LOJ] A, B, C, D) is {A B C} [LOJ] {D}. + *

+ * joinType is RIGHT (OUTER/SEMI/ANTI) JOIN, + * firstChild [RIGHT (OUTER/SEMI/ANTI) JOIN] children[1, last]. + * eg: MJ([ROJ] A, B, C, D) is {A} [ROJ] {B C D}. */ -public class MultiJoin extends PlanVisitor { +public class MultiJoin extends AbstractLogicalPlan { /* * topJoin * / \ MultiJoin @@ -52,154 +61,125 @@ public class MultiJoin extends PlanVisitor { * / \ A B C * A B */ - public final List joinInputs = new ArrayList<>(); - public final List conjunctsForAllHashJoins = new ArrayList<>(); - private List conjunctsKeepInFilter = new ArrayList<>(); - /** - * reorderJoinsAccordingToConditions - * - * @return join or filter - */ - public Optional reorderJoinsAccordingToConditions() { - if (joinInputs.size() >= 2) { - Plan root = reorderJoinsAccordingToConditions(joinInputs, conjunctsForAllHashJoins); - return Optional.of(PlanUtils.filterOrSelf(conjunctsKeepInFilter, root)); - } - return Optional.empty(); + // Push predicates into it. + // But joinFilter shouldn't contain predicate which just contains one predicate like `T.key > 1`. + // Because these predicate should be pushdown. + private final List joinFilter; + // MultiJoin just contains one OUTER/SEMI/ANTI. + private final JoinType joinType; + // When contains one OUTER/SEMI/ANTI join, keep separately its condition. + private final List notInnerJoinConditions; + + public MultiJoin(List inputs, List joinFilter, JoinType joinType, + List notInnerJoinConditions) { + super(PlanType.LOGICAL_MULTI_JOIN, inputs.toArray(new Plan[0])); + this.joinFilter = Objects.requireNonNull(joinFilter); + this.joinType = joinType; + this.notInnerJoinConditions = Objects.requireNonNull(notInnerJoinConditions); } - /** - * Reorder join orders according to join conditions to eliminate cross join. - *

- * Let's say we have input join tables: [t1, t2, t3] and - * conjunctive predicates: [t1.id=t3.id, t2.id=t3.id] - * The input join for t1 and t2 is cross join. - *

- * The algorithm split join inputs into two groups: `left input` t1 and `candidate right input` [t2, t3]. - * Try to find an inner join from t1 and candidate right inputs [t2, t3], if any combination - * of [Join(t1, t2), Join(t1, t3)] could be optimized to inner join according to the join conditions. - *

- * As a result, Join(t1, t3) is an inner join. - * Then the logic is applied to the rest of [Join(t1, t3), t2] recursively. - */ - private Plan reorderJoinsAccordingToConditions(List joinInputs, List conjuncts) { - if (joinInputs.size() == 2) { - Pair, List> pair = JoinUtils.extractExpressionForHashTable( - joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), - joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), - conjuncts); - List joinConditions = pair.first; - conjunctsKeepInFilter = pair.second; - - return new LogicalJoin<>(JoinType.INNER_JOIN, - ExpressionUtils.EMPTY_CONDITION, - joinConditions, - joinInputs.get(0), joinInputs.get(1)); - } - // input size >= 3; - Plan left = joinInputs.get(0); - List candidate = joinInputs.subList(1, joinInputs.size()); - - List leftOutput = left.getOutput(); - Optional rightOpt = candidate.stream().filter(right -> { - List rightOutput = right.getOutput(); - - Set joinOutput = getJoinOutput(left, right); - Optional joinCond = conjuncts.stream() - .filter(expr -> { - Set exprInputSlots = expr.getInputSlots(); - if (exprInputSlots.isEmpty()) { - return false; - } - - if (new HashSet<>(leftOutput).containsAll(exprInputSlots)) { - return false; - } - - if (new HashSet<>(rightOutput).containsAll(exprInputSlots)) { - return false; - } - - return joinOutput.containsAll(exprInputSlots); - }).findFirst(); - return joinCond.isPresent(); - }).findFirst(); - - Plan right = rightOpt.orElseGet(() -> candidate.get(1)); - Pair, List> pair = JoinUtils.extractExpressionForHashTable( - joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), - joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), - conjuncts); - List joinConditions = pair.first; - List nonJoinConditions = pair.second; - LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, ExpressionUtils.EMPTY_CONDITION, - joinConditions, left, right); - - List newInputs = new ArrayList<>(); - newInputs.add(join); - newInputs.addAll(candidate.stream().filter(plan -> !right.equals(plan)).collect(Collectors.toList())); - return reorderJoinsAccordingToConditions(newInputs, nonJoinConditions); + public JoinType getJoinType() { + return joinType; } - private Map> splitConjuncts(List conjuncts, Set slots) { - return conjuncts.stream().collect(Collectors.partitioningBy( - // TODO: support non equal to conditions. - expr -> expr instanceof EqualTo && slots.containsAll(expr.getInputSlots()))); + public List getJoinFilter() { + return joinFilter; } - private Set getJoinOutput(Plan left, Plan right) { - HashSet joinOutput = new HashSet<>(); - joinOutput.addAll(left.getOutput()); - joinOutput.addAll(right.getOutput()); - return joinOutput; + public List getNotInnerJoinConditions() { + return notInnerJoinConditions; } @Override - public Void visit(Plan plan, Void context) { - for (Plan child : plan.children()) { - child.accept(this, context); - } - return null; + public MultiJoin withChildren(List children) { + return new MultiJoin(children, joinFilter, joinType, notInnerJoinConditions); } @Override - public Void visitLogicalFilter(LogicalFilter filter, Void context) { - Plan child = filter.child(); - if (child instanceof LogicalJoin) { - conjunctsForAllHashJoins.addAll(ExpressionUtils.extractConjunction(filter.getPredicates())); + public List computeOutput() { + Builder builder = ImmutableList.builder(); + + if (joinType.isInnerOrCrossJoin()) { + // INNER/CROSS + for (Plan child : children) { + builder.addAll(child.getOutput()); + } + return builder.build(); } - child.accept(this, context); - return null; + // FULL OUTER JOIN + if (joinType.isFullOuterJoin()) { + for (Plan child : children) { + builder.addAll(child.getOutput().stream() + .map(o -> o.withNullable(true)) + .collect(Collectors.toList())); + } + return builder.build(); + } + + // RIGHT OUTER | RIGHT_SEMI/ANTI + if (joinType.isRightJoin()) { + // RIGHT OUTER + if (joinType.isRightOuterJoin()) { + builder.addAll(children.get(0).getOutput().stream() + .map(o -> o.withNullable(true)) + .collect(Collectors.toList())); + } + for (int i = 1; i < children.size(); i++) { + builder.addAll(children.get(i).getOutput()); + } + + return builder.build(); + } + + // LEFT OUTER | LEFT_SEMI/ANTI + if (joinType.isLeftJoin()) { + for (int i = 0; i < children.size() - 1; i++) { + builder.addAll(children.get(i).getOutput()); + } + // LEFT OUTER + if (joinType.isLeftOuterJoin()) { + builder.addAll(children.get(arity() - 1).getOutput().stream() + .map(o -> o.withNullable(true)) + .collect(Collectors.toList())); + } + + return builder.build(); + } + + throw new RuntimeException("unreachable"); } @Override - public Void visitLogicalJoin(LogicalJoin join, Void context) { - if (join.getJoinType() != JoinType.CROSS_JOIN && join.getJoinType() != JoinType.INNER_JOIN) { - return null; - } + public R accept(PlanVisitor visitor, C context) { + throw new RuntimeException("multiJoin can't invoke accept"); + } - join.left().accept(this, context); - join.right().accept(this, context); + @Override + public List getExpressions() { + return new Builder() + .addAll(joinFilter) + .addAll(notInnerJoinConditions) + .build(); + } - conjunctsForAllHashJoins.addAll(join.getHashJoinConjuncts()); - conjunctsForAllHashJoins.addAll(join.getOtherJoinConjuncts()); + @Override + public Plan withGroupExpression(Optional groupExpression) { + throw new RuntimeException("multiJoin can't invoke withGroupExpression"); + } - Plan leftChild = join.left(); - if (join.left() instanceof LogicalFilter) { - leftChild = join.left().child(0); - } - if (leftChild instanceof GroupPlan) { - joinInputs.add(join.left()); - } - Plan rightChild = join.right(); - if (join.right() instanceof LogicalFilter) { - rightChild = join.right().child(0); - } - if (rightChild instanceof GroupPlan) { - joinInputs.add(join.right()); - } - return null; + @Override + public Plan withLogicalProperties(Optional logicalProperties) { + throw new RuntimeException("multiJoin can't invoke withLogicalProperties"); + } + + @Override + public String toString() { + return Utils.toSqlString("MultiJoin", + "joinType", joinType, + "joinFilter", joinFilter, + "notInnerJoinConditions", notInnerJoinConditions + ); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java index 88ad3eff60..1cbdc370e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java @@ -17,12 +17,31 @@ package org.apache.doris.nereids.rules.rewrite.logical; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.JoinUtils; +import org.apache.doris.nereids.util.PlanUtils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableList.Builder; +import com.google.common.collect.Lists; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; /** * Try to eliminate cross join via finding join conditions in filters and change the join orders. @@ -37,21 +56,308 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; * SELECT * FROM t1 JOIN t3 ON t1.id=t3.id JOIN t2 ON t2.id=t3.id * *

- * TODO: This is tested by SSB queries currently, add more `unit` test for this rule - * when we have a plan building and comparing framework. + * Using the {@link MultiJoin} to complete this task. + * {Join cluster}: contain multiple join with filter inside. + * */ public class ReorderJoin extends OneRewriteRuleFactory { @Override public Rule build() { return logicalFilter(subTree(LogicalJoin.class, LogicalFilter.class)).thenApply(ctx -> { LogicalFilter filter = ctx.root; - if (!ctx.cascadesContext.getConnectContext().getSessionVariable() - .isEnableNereidsReorderToEliminateCrossJoin()) { - return filter; - } - MultiJoin multiJoin = new MultiJoin(); - filter.accept(multiJoin, null); - return multiJoin.reorderJoinsAccordingToConditions().orElse(filter); + + Plan plan = joinToMultiJoin(filter); + Preconditions.checkState(plan instanceof MultiJoin); + MultiJoin multiJoin = (MultiJoin) plan; + Plan after = multiJoinToJoin(multiJoin); + return after; }).toRule(RuleType.REORDER_JOIN); } + + /** + * Recursively convert to + * {@link LogicalJoin} or + * {@link LogicalFilter}--{@link LogicalJoin} + * --> {@link MultiJoin} + */ + public Plan joinToMultiJoin(Plan plan) { + // subtree can't specify the end of Pattern. so end can be GroupPlan or Filter + if (plan instanceof GroupPlan + || (plan instanceof LogicalFilter && plan.child(0) instanceof GroupPlan)) { + return plan; + } + + List inputs = Lists.newArrayList(); + List joinFilter = Lists.newArrayList(); + List notInnerJoinConditions = Lists.newArrayList(); + + LogicalJoin join; + // Implicit rely on {rule: MergeFilters}, so don't exist filter--filter--join. + if (plan instanceof LogicalFilter) { + LogicalFilter filter = (LogicalFilter) plan; + joinFilter.addAll(ExpressionUtils.extractConjunction(filter.getPredicates())); + join = (LogicalJoin) filter.child(); + } else { + join = (LogicalJoin) plan; + } + + if (join.getJoinType().isInnerOrCrossJoin()) { + joinFilter.addAll(join.getHashJoinConjuncts()); + joinFilter.addAll(join.getOtherJoinConjuncts()); + } else { + notInnerJoinConditions.addAll(join.getHashJoinConjuncts()); + notInnerJoinConditions.addAll(join.getOtherJoinConjuncts()); + } + + // recursively convert children. + Plan left = joinToMultiJoin(join.left()); + Plan right = joinToMultiJoin(join.right()); + + boolean changeLeft = join.getJoinType().isRightJoin() + || join.getJoinType().isFullOuterJoin(); + if (canCombine(left, changeLeft)) { + MultiJoin leftMultiJoin = (MultiJoin) left; + inputs.addAll(leftMultiJoin.children()); + joinFilter.addAll(leftMultiJoin.getJoinFilter()); + } else { + inputs.add(left); + } + + boolean changeRight = join.getJoinType().isLeftJoin() + || join.getJoinType().isFullOuterJoin(); + if (canCombine(right, changeRight)) { + MultiJoin rightMultiJoin = (MultiJoin) right; + inputs.addAll(rightMultiJoin.children()); + joinFilter.addAll(rightMultiJoin.getJoinFilter()); + } else { + inputs.add(right); + } + + return new MultiJoin( + inputs, + joinFilter, + join.getJoinType(), + notInnerJoinConditions); + } + + /** + * Recursively convert to + * {@link MultiJoin} + * --> + * {@link LogicalJoin} or + * {@link LogicalFilter}--{@link LogicalJoin} + *

+ * When all input is CROSS/Inner Join, all join will be flattened. + * Otherwise, we will split {join cluster} into multiple {@link MultiJoin}. + *

+ * Here are examples of the {@link MultiJoin}s constructed after this rules has been applied. + *

+ * simple example: + *

    + *
  • A JOIN B --> MJ(A, B) + *
  • A JOIN B JOIN C JOIN D --> MJ(A, B, C, D) + *
  • A LEFT (OUTER/SEMI/ANTI) JOIN B --> MJ([LOJ/LSJ/LAJ]A, B) + *
  • A LEFT (OUTER/SEMI/ANTI) JOIN B --> MJ([ROJ/RSJ/RAJ]A, B) + *
  • A FULL JOIN B --> MJ[FOJ](A, B) + *
+ *

+ *

+ * complex example: + *

    + *
  • A LEFT OUTER JOIN (B JOIN C) --> MJ([LOJ]A, MJ(B, C))) + *
  • (A JOIN B) LEFT JOIN C --> MJ(A, B, C) + *
  • (A LEFT OUTER JOIN B) JOIN C --> MJ(MJ(A, B), C) + *
  • A LEFT JOIN (B FULL JOIN C) --> MJ(A, MJ[full](B, C)) + *
  • (A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) --> MJ[full](MJ(A, B), MJ(C, D)) + *
+ *

+ * more complex example: + *
    + *
  • A JOIN B JOIN C LEFT JOIN D --> MJ([LOJ]A, B, C, D) + *
  • A JOIN B JOIN C LEFT JOIN (D JOIN F) --> MJ([LOJ]A, B, C, MJ(D, F)) + *
  • A RIGHT JOIN (B JOIN C JOIN D)--> MJ([ROJ]A, B, C, D) + *
  • A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D)) + *
+ *

+ *

+ * Graphic presentation: + * A JOIN B JOIN C LEFT JOIN D JOIN F + * left left│ + * A B C D F ──► A B C │ D F ──► MJ(LOJ A,B,C,MJ(DF) + *

+ * A JOIN B RIGHT JOIN C JOIN D JOIN F + * right │right + * A B C D F ──► A B │ C D F ──► MJ(A,B,MJ(ROJ C,D,F) + *

+ * (A JOIN B JOIN C) FULL JOIN (D JOIN F) + * full │ + * A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F)) + *

+ */ + public Plan multiJoinToJoin(MultiJoin multiJoin) { + if (multiJoin.arity() == 1) { + return PlanUtils.filterOrSelf(multiJoin.getJoinFilter(), multiJoin.child(0)); + } + + Builder builder = ImmutableList.builder(); + // recursively hanlde multiJoin children. + for (Plan child : multiJoin.children()) { + if (child instanceof MultiJoin) { + MultiJoin childMultiJoin = (MultiJoin) child; + builder.add(multiJoinToJoin(childMultiJoin)); + } else { + builder.add(child); + } + } + MultiJoin multiJoinHandleChildren = multiJoin.withChildren(builder.build()); + + if (!multiJoinHandleChildren.getJoinType().isInnerOrCrossJoin()) { + List leftFilter = Lists.newArrayList(); + List rightFilter = Lists.newArrayList(); + List remainingFilter = Lists.newArrayList(); + Plan left; + Plan right; + if (multiJoinHandleChildren.getJoinType().isLeftJoin()) { + left = multiJoinToJoin(new MultiJoin( + multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1), + leftFilter, + JoinType.INNER_JOIN, + ExpressionUtils.EMPTY_CONDITION)); + right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1); + } else if (multiJoinHandleChildren.getJoinType().isRightJoin()) { + left = multiJoinHandleChildren.child(0); + right = multiJoinToJoin(new MultiJoin( + multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()), + rightFilter, + JoinType.INNER_JOIN, + ExpressionUtils.EMPTY_CONDITION)); + } else { + left = multiJoinToJoin(new MultiJoin( + multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1), + leftFilter, + JoinType.INNER_JOIN, + ExpressionUtils.EMPTY_CONDITION)); + right = multiJoinToJoin(new MultiJoin( + multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()), + rightFilter, + JoinType.INNER_JOIN, + ExpressionUtils.EMPTY_CONDITION)); + } + + // split filter + for (Expression expr : multiJoinHandleChildren.getJoinFilter()) { + Set exprInputSlots = expr.getInputSlots(); + Preconditions.checkState(!exprInputSlots.isEmpty()); + + if (left.getOutputSet().containsAll(exprInputSlots)) { + leftFilter.add(expr); + } else if (right.getOutputSet().containsAll(exprInputSlots)) { + rightFilter.add(expr); + } else if (multiJoin.getOutputSet().containsAll(exprInputSlots)) { + remainingFilter.add(expr); + } else { + NereidsPlanner.LOG.error("invalid expression, exist slot that multiJoin don't contains."); + throw new RuntimeException("invalid expression, exist slot that multiJoin don't contains."); + } + } + + return PlanUtils.filterOrSelf(remainingFilter, new LogicalJoin<>( + multiJoinHandleChildren.getJoinType(), + ExpressionUtils.EMPTY_CONDITION, + multiJoinHandleChildren.getNotInnerJoinConditions(), + PlanUtils.filterOrSelf(leftFilter, left), PlanUtils.filterOrSelf(rightFilter, right))); + } + + // following this multiJoin just contain INNER/CROSS. + List joinFilter = multiJoinHandleChildren.getJoinFilter(); + + Plan left = multiJoinHandleChildren.child(0); + List candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()); + + LogicalJoin join = findInnerJoin(left, candidates, joinFilter); + List newInputs = Lists.newArrayList(); + newInputs.add(join); + newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList())); + + joinFilter.removeAll(join.getHashJoinConjuncts()); + joinFilter.removeAll(join.getOtherJoinConjuncts()); + // TODO(wj): eliminate this recursion. + return multiJoinToJoin(new MultiJoin( + newInputs, + joinFilter, + JoinType.INNER_JOIN, + ExpressionUtils.EMPTY_CONDITION)); + } + + /** + * Returns whether an input can be merged without changing semantics. + * + * @param input input into a MultiJoin or (GroupPlan|LogicalFilter) + * @param changeChildren generate nullable or one side child not exist. + * @return true if the input can be combined into a parent MultiJoin + */ + private static boolean canCombine(Plan input, boolean changeChildren) { + return input instanceof MultiJoin + && ((MultiJoin) input).getJoinType().isInnerOrCrossJoin() + && !changeChildren; + } + + private Set getJoinOutput(Plan left, Plan right) { + HashSet joinOutput = new HashSet<>(); + joinOutput.addAll(left.getOutput()); + joinOutput.addAll(right.getOutput()); + return joinOutput; + } + + /** + * Find hash condition from joinFilter + * Get InnerJoin from left, right from [candidates]. + * + * @return InnerJoin or CrossJoin{left, last of [candidates]} + */ + private LogicalJoin findInnerJoin(Plan left, List candidates, + List joinFilter) { + Set leftOutputSet = left.getOutputSet(); + for (int i = 0; i < candidates.size(); i++) { + Plan candidate = candidates.get(i); + Set rightOutputSet = candidate.getOutputSet(); + + Set joinOutput = getJoinOutput(left, candidate); + + List currentJoinFilter = joinFilter.stream() + .filter(expr -> { + Set exprInputSlots = expr.getInputSlots(); + Preconditions.checkState(exprInputSlots.size() > 1, + "Predicate like table.col > 1 must have pushdown."); + if (leftOutputSet.containsAll(exprInputSlots)) { + return false; + } + if (rightOutputSet.containsAll(exprInputSlots)) { + return false; + } + + return joinOutput.containsAll(exprInputSlots); + }).collect(Collectors.toList()); + + Pair, List> pair = JoinUtils.extractExpressionForHashTable( + left.getOutput(), candidate.getOutput(), currentJoinFilter); + List hashJoinConditions = pair.first; + List otherJoinConditions = pair.second; + if (!hashJoinConditions.isEmpty()) { + return new LogicalJoin<>(JoinType.INNER_JOIN, + hashJoinConditions, otherJoinConditions, + left, candidate); + } + + if (i == candidates.size() - 1) { + return new LogicalJoin<>(JoinType.CROSS_JOIN, + hashJoinConditions, otherJoinConditions, + left, candidate); + } + } + throw new RuntimeException("findInnerJoin: can't reach here"); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java index 764eecd154..b9e3d909c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/JoinType.java @@ -92,6 +92,14 @@ public enum JoinType { return this == INNER_JOIN; } + public final boolean isInnerOrCrossJoin() { + return this == INNER_JOIN || this == CROSS_JOIN; + } + + public final boolean isLeftJoin() { + return this == LEFT_OUTER_JOIN || this == LEFT_ANTI_JOIN || this == LEFT_SEMI_JOIN; + } + public final boolean isRightJoin() { return this == RIGHT_OUTER_JOIN || this == RIGHT_ANTI_JOIN || this == RIGHT_SEMI_JOIN; } @@ -108,6 +116,14 @@ public enum JoinType { return this == RIGHT_OUTER_JOIN; } + public final boolean isLeftSemiOrAntiJoin() { + return this == LEFT_SEMI_JOIN || this == LEFT_ANTI_JOIN; + } + + public final boolean isRightSemiOrAntiJoin() { + return this == RIGHT_SEMI_JOIN || this == RIGHT_ANTI_JOIN; + } + public final boolean isSemiOrAntiJoin() { return this == LEFT_SEMI_JOIN || this == RIGHT_SEMI_JOIN || this == LEFT_ANTI_JOIN || this == RIGHT_ANTI_JOIN; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java index 9ef8a7eae9..5fb4564b32 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/PlanType.java @@ -42,6 +42,7 @@ public enum PlanType { LOGICAL_SELECT_HINT, LOGICAL_ASSERT_NUM_ROWS, LOGICAL_HAVING, + LOGICAL_MULTI_JOIN, GROUP_PLAN, // physical plan diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java index 9e5f624fec..0029c205f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFilter.java @@ -102,7 +102,7 @@ public class LogicalFilter extends LogicalUnary R accept(PlanVisitor visitor, C context) { - return visitor.visitLogicalFilter((LogicalFilter) this, context); + return visitor.visitLogicalFilter(this, context); } @Override diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java index 7aa28fe259..4a2bcb1cb8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java @@ -17,26 +17,15 @@ package org.apache.doris.nereids.datasets.ssb; -import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; -import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; -import org.apache.doris.nereids.util.PlanRewriter; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.ArrayList; import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; -public class SSBJoinReorderTest extends SSBTestBase { +public class SSBJoinReorderTest extends SSBTestBase implements PatternMatchSupported { @Test public void q4_1() { test( @@ -47,9 +36,11 @@ public class SSBJoinReorderTest extends SSBTestBase { "(lo_suppkey = s_suppkey)", "(lo_partkey = p_partkey)" ), - ImmutableList.of("(((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) " - + "= CAST('AMERICA' AS STRING))) AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) " - + "OR (CAST(p_mfgr AS STRING) = CAST('MFGR#2' AS STRING))))") + ImmutableList.of( + "(CAST(c_region AS STRING) = 'AMERICA')", + "(CAST(s_region AS STRING) = 'AMERICA')", + "((CAST(p_mfgr AS STRING) = 'MFGR#1') OR (CAST(p_mfgr AS STRING) = 'MFGR#2'))" + ) ); } @@ -64,10 +55,11 @@ public class SSBJoinReorderTest extends SSBTestBase { "(lo_partkey = p_partkey)" ), ImmutableList.of( - "((((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) " - + "= CAST('AMERICA' AS STRING))) AND ((d_year = 1997) OR (d_year = 1998))) " - + "AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) OR (CAST(p_mfgr AS STRING) " - + "= CAST('MFGR#2' AS STRING))))") + "((d_year = 1997) OR (d_year = 1998))", + "(CAST(c_region AS STRING) = 'AMERICA')", + "(CAST(s_region AS STRING) = 'AMERICA')", + "((CAST(p_mfgr AS STRING) = 'MFGR#1') OR (CAST(p_mfgr AS STRING) = 'MFGR#2'))" + ) ); } @@ -81,94 +73,31 @@ public class SSBJoinReorderTest extends SSBTestBase { "(lo_suppkey = s_suppkey)", "(lo_partkey = p_partkey)" ), - ImmutableList.of("(((CAST(s_nation AS STRING) = CAST('UNITED STATES' AS STRING)) AND ((d_year = 1997) " - + "OR (d_year = 1998))) AND (CAST(p_category AS STRING) = CAST('MFGR#14' AS STRING)))") + ImmutableList.of( + "((d_year = 1997) OR (d_year = 1998))", + "(CAST(s_nation AS STRING) = 'UNITED STATES')", + "(CAST(p_category AS STRING) = 'MFGR#14')" + ) ); } private void test(String sql, List expectJoinConditions, List expectFilterPredicates) { - LogicalPlan analyzed = analyze(sql); - LogicalPlan plan = testJoinReorder(analyzed); - System.out.println(plan.treeString()); - new PlanChecker(expectJoinConditions, expectFilterPredicates).check(plan); - } + PlanChecker planChecker = PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .printlnTree(); - private LogicalPlan testJoinReorder(LogicalPlan plan) { - return (LogicalPlan) PlanRewriter.topDownRewrite(plan, connectContext, new ReorderJoin()); - } - - private static class PlanChecker extends PlanVisitor { - private final List joinInputs = new ArrayList<>(); - private final List joins = new ArrayList<>(); - private final List filters = new ArrayList<>(); - // TODO: it's tricky to compare expression by string, use a graceful manner to do this in the future. - private final List expectJoinConditions; - private final List expectFilterPredicates; - - public PlanChecker(List expectJoinConditions, List expectFilterPredicates) { - this.expectJoinConditions = expectJoinConditions; - this.expectFilterPredicates = expectFilterPredicates; + for (String expectJoinCondition : expectJoinConditions) { + planChecker.matches( + innerLogicalJoin().when( + join -> join.getHashJoinConjuncts().get(0).toSql().equals(expectJoinCondition)) + ); } - public void check(Plan plan) { - plan.accept(this, new Context(null)); - - // check join table orders - Assertions.assertEquals( - ImmutableList.of("dates", "lineorder", "customer", "supplier", "part"), - joinInputs.stream().map(p -> p.getTable().getName()).collect(Collectors.toList())); - - // check join conditions - List actualJoinConditions = joins.stream().map(j -> { - Optional condition = j.getOnClauseCondition(); - return condition.map(Expression::toSql).orElse(""); - }).collect(Collectors.toList()); - - Assertions.assertEquals(expectJoinConditions, actualJoinConditions); - - // check filter predicates - List actualFilterPredicates = filters.stream() - .map(f -> f.getPredicates().toSql()).collect(Collectors.toList()); - Assertions.assertEquals(expectFilterPredicates, actualFilterPredicates); - } - - @Override - public Void visit(Plan plan, Context context) { - for (Plan child : plan.children()) { - child.accept(this, new Context(plan)); - } - return null; - } - - @Override - public Void visitLogicalRelation(LogicalRelation relation, Context context) { - if (context.parent instanceof LogicalJoin) { - joinInputs.add(relation); - } - return null; - } - - @Override - public Void visitLogicalFilter(LogicalFilter filter, Context context) { - filters.add(filter); - filter.child().accept(this, new Context(filter)); - return null; - } - - @Override - public Void visitLogicalJoin(LogicalJoin join, Context context) { - join.left().accept(this, new Context(join)); - join.right().accept(this, new Context(join)); - joins.add(join); - return null; - } - } - - private static class Context { - public final Plan parent; - - public Context(Plan parent) { - this.parent = parent; + for (String expectFilterPredicate : expectFilterPredicates) { + planChecker.matches( + logicalFilter().when(filter -> filter.getPredicates().toSql().equals(expectFilterPredicate)) + ); } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/JoinReorderTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/JoinReorderTest.java deleted file mode 100644 index e0806564ef..0000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/JoinReorderTest.java +++ /dev/null @@ -1,48 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite.logical; - -import org.apache.doris.nereids.analyzer.UnboundRelation; -import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PatternMatchSupported; -import org.apache.doris.nereids.util.PlanChecker; - -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Test; - -public class JoinReorderTest implements PatternMatchSupported { - - /** - * To test do not throw unexpected exception when join type is not inner or cross. - */ - @Test - public void testWithOuterJoin() { - UnboundRelation relation1 = new UnboundRelation(Lists.newArrayList("db", "table1")); - UnboundRelation relation2 = new UnboundRelation(Lists.newArrayList("db", "table2")); - LogicalJoin outerJoin = new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN, relation1, relation2); - LogicalFilter logicalFilter = new LogicalFilter<>(BooleanLiteral.FALSE, outerJoin); - - PlanChecker.from(MemoTestUtils.createConnectContext(), logicalFilter) - .applyBottomUp(new ReorderJoin()) - .matches(logicalFilter(leftOuterLogicalJoin(unboundRelation(), unboundRelation()))); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java new file mode 100644 index 0000000000..ffb7e16510 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java @@ -0,0 +1,183 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +import java.util.List; + +class ReorderJoinTest implements PatternMatchSupported { + + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0); + + @Test + public void testLeftOuterJoin() { + ImmutableList plans = ImmutableList.of( + new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build(), + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build() + ); + + check(plans); + } + + @Test + public void testRightOuterJoin() { + ImmutableList plans = ImmutableList.of( + new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build(), + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build() + ); + + check(plans); + } + + @Test + public void testLeftSemiJoin() { + ImmutableList plans = ImmutableList.of( + new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build(), + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .hashJoinUsing(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build() + ); + + check(plans); + } + + @Test + public void testRightSemiJoin() { + ImmutableList plans = ImmutableList.of( + new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0))) + .build(), + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build() + ); + + check(plans); + } + + @Test + public void testFullOuterJoin() { + ImmutableList plans = ImmutableList.of( + new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(0, 0)) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build(), + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(0, 0)) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build() + ); + + check(plans); + } + + public void check(List plans) { + for (LogicalPlan plan : plans) { + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyBottomUp(new ReorderJoin()) + .matchesFromRoot( + logicalJoin( + logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()), + leafPlan() + ).whenNot(join -> join.getJoinType().isCrossJoin()) + ) + .printlnTree(); + } + } + + /** + * join + * crossjoin / \ + * / \ join D + * innerjoin innerjoin ──► / \ + * / \ / \ join C + * A B C D / \ + * A B + */ + @Test + public void testInnerOrCrossJoin() { + LogicalPlan leftJoin = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) + .build(); + LogicalPlan rightJoin = new LogicalPlanBuilder(scan3) + .hashJoinUsing(scan4, JoinType.INNER_JOIN, Pair.of(0, 0)) + .build(); + + LogicalPlan plan = new LogicalPlanBuilder(leftJoin) + .hashJoinEmptyOn(rightJoin, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0))) + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .applyBottomUp(new ReorderJoin()) + .matchesFromRoot( + logicalJoin( + logicalJoin( + logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()), + leafPlan() + ).whenNot(join -> join.getJoinType().isCrossJoin()), + leafPlan() + ).whenNot(join -> join.getJoinType().isCrossJoin()) + ) + .printlnTree(); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java similarity index 59% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java index 2d958a613b..4a52796734 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/SqlTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java @@ -23,9 +23,12 @@ import org.apache.doris.nereids.util.PatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; +import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; -public class SqlTest extends TestWithFeService implements PatternMatchSupported { +import java.util.List; + +public class MultiJoinTest extends TestWithFeService implements PatternMatchSupported { @Override protected void runBeforeAll() throws Exception { createDatabase("test"); @@ -68,17 +71,44 @@ public class SqlTest extends TestWithFeService implements PatternMatchSupported } @Test - void testSql() { - // String sql = "SELECT *" - // + " FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id" - // + " WHERE T1.id = T2.id"; - String sql = "SELECT *" - + " FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1" - + " WHERE T1.id = T2.id"; - PlanChecker.from(connectContext) - .analyze(sql) - .applyTopDown(new ReorderJoin()) - .implement() - .printlnTree(); + void testMultiJoinEliminateCross() { + List sqls = ImmutableList.builder() + .add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id") + .add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id") + .build(); + + for (String sql : sqls) { + PlanChecker.from(connectContext) + .analyze(sql) + .applyBottomUp(new ReorderJoin()) + .matches( + logicalJoin( + logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()), + leafPlan() + ).whenNot(join -> join.getJoinType().isCrossJoin()) + ) + .printlnTree(); + } + } + + @Test + void testMultiJoinExistCross() { + List sqls = ImmutableList.builder() + .add("SELECT * FROM T2 LEFT SEMI JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id > T2.id") + .build(); + + for (String sql : sqls) { + PlanChecker.from(connectContext) + .analyze(sql) + .applyBottomUp(new ReorderJoin()) + .matches( + logicalJoin( + logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()), + leafPlan() + ).when(join -> join.getJoinType().isCrossJoin()) + .whenNot(join -> join.getOtherJoinConjuncts().isEmpty()) + ) + .printlnTree(); + } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java index 94f5149bd3..219acbe669 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/AnalyzeWhereSubqueryTest.java @@ -81,7 +81,6 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte = "select * from t6 where t6.k1 < (select max(aa) from (select v1 as aa from t7 where t6.k2=t7.v2) t2 )"; private final List testSql = ImmutableList.of( sql1, sql2, sql3, sql4, sql5, sql6, sql7, sql8, sql9, sql10 - ); @Override