[feature](Nereids): use Multi join to rearrange join to eliminate cross join by using predicate. (#13353)
This commit is contained in:
@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
|
||||
@ -61,7 +62,11 @@ public class FindHashConditionForJoin extends OneRewriteRuleFactory {
|
||||
.addAll(join.getHashJoinConjuncts())
|
||||
.addAll(extractedHashJoinConjuncts)
|
||||
.build();
|
||||
return new LogicalJoin<>(join.getJoinType(),
|
||||
JoinType joinType = join.getJoinType();
|
||||
if (joinType == JoinType.CROSS_JOIN && !combinedHashJoinConjuncts.isEmpty()) {
|
||||
joinType = JoinType.INNER_JOIN;
|
||||
}
|
||||
return new LogicalJoin<>(joinType,
|
||||
combinedHashJoinConjuncts,
|
||||
remainedNonHashJoinConjuncts,
|
||||
join.left(), join.right());
|
||||
|
||||
@ -17,34 +17,43 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.memo.GroupExpression;
|
||||
import org.apache.doris.nereids.properties.LogicalProperties;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.PlanType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.AbstractLogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* A MultiJoin represents a join of N inputs (NAry-Join).
|
||||
* The regular Join represent strictly binary input (Binary-Join).
|
||||
* <p>
|
||||
* One {@link MultiJoin} just contains one {@link JoinType} of SEMI/ANTI/OUTER Join.
|
||||
* <p>
|
||||
* joinType is FULL OUTER JOIN, children.size() == 2.
|
||||
* leftChild [FULL OUTER JOIN] rightChild.
|
||||
* <p>
|
||||
* joinType is LEFT (OUTER/SEMI/ANTI) JOIN,
|
||||
* children[0, last) [LEFT (OUTER/SEMI/ANTI) JOIN] lastChild.
|
||||
* eg: MJ([LOJ] A, B, C, D) is {A B C} [LOJ] {D}.
|
||||
* <p>
|
||||
* joinType is RIGHT (OUTER/SEMI/ANTI) JOIN,
|
||||
* firstChild [RIGHT (OUTER/SEMI/ANTI) JOIN] children[1, last].
|
||||
* eg: MJ([ROJ] A, B, C, D) is {A} [ROJ] {B C D}.
|
||||
*/
|
||||
public class MultiJoin extends PlanVisitor<Void, Void> {
|
||||
public class MultiJoin extends AbstractLogicalPlan {
|
||||
/*
|
||||
* topJoin
|
||||
* / \ MultiJoin
|
||||
@ -52,154 +61,125 @@ public class MultiJoin extends PlanVisitor<Void, Void> {
|
||||
* / \ A B C
|
||||
* A B
|
||||
*/
|
||||
public final List<Plan> joinInputs = new ArrayList<>();
|
||||
public final List<Expression> conjunctsForAllHashJoins = new ArrayList<>();
|
||||
private List<Expression> conjunctsKeepInFilter = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* reorderJoinsAccordingToConditions
|
||||
*
|
||||
* @return join or filter
|
||||
*/
|
||||
public Optional<Plan> reorderJoinsAccordingToConditions() {
|
||||
if (joinInputs.size() >= 2) {
|
||||
Plan root = reorderJoinsAccordingToConditions(joinInputs, conjunctsForAllHashJoins);
|
||||
return Optional.of(PlanUtils.filterOrSelf(conjunctsKeepInFilter, root));
|
||||
}
|
||||
return Optional.empty();
|
||||
// Push predicates into it.
|
||||
// But joinFilter shouldn't contain a predicate that references only a single slot, like `T.key > 1`.
|
||||
// Because such predicates should already have been pushed down.
|
||||
private final List<Expression> joinFilter;
|
||||
// MultiJoin just contains one OUTER/SEMI/ANTI.
|
||||
private final JoinType joinType;
|
||||
// When contains one OUTER/SEMI/ANTI join, keep separately its condition.
|
||||
private final List<Expression> notInnerJoinConditions;
|
||||
|
||||
public MultiJoin(List<Plan> inputs, List<Expression> joinFilter, JoinType joinType,
|
||||
List<Expression> notInnerJoinConditions) {
|
||||
super(PlanType.LOGICAL_MULTI_JOIN, inputs.toArray(new Plan[0]));
|
||||
this.joinFilter = Objects.requireNonNull(joinFilter);
|
||||
this.joinType = joinType;
|
||||
this.notInnerJoinConditions = Objects.requireNonNull(notInnerJoinConditions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reorder join orders according to join conditions to eliminate cross join.
|
||||
* <p/>
|
||||
* Let's say we have input join tables: [t1, t2, t3] and
|
||||
* conjunctive predicates: [t1.id=t3.id, t2.id=t3.id]
|
||||
* The input join for t1 and t2 is cross join.
|
||||
* <p/>
|
||||
* The algorithm split join inputs into two groups: `left input` t1 and `candidate right input` [t2, t3].
|
||||
* Try to find an inner join from t1 and candidate right inputs [t2, t3], if any combination
|
||||
* of [Join(t1, t2), Join(t1, t3)] could be optimized to inner join according to the join conditions.
|
||||
* <p/>
|
||||
* As a result, Join(t1, t3) is an inner join.
|
||||
* Then the logic is applied to the rest of [Join(t1, t3), t2] recursively.
|
||||
*/
|
||||
private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) {
|
||||
if (joinInputs.size() == 2) {
|
||||
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
|
||||
joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
|
||||
joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
|
||||
conjuncts);
|
||||
List<Expression> joinConditions = pair.first;
|
||||
conjunctsKeepInFilter = pair.second;
|
||||
|
||||
return new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION,
|
||||
joinConditions,
|
||||
joinInputs.get(0), joinInputs.get(1));
|
||||
}
|
||||
// input size >= 3;
|
||||
Plan left = joinInputs.get(0);
|
||||
List<Plan> candidate = joinInputs.subList(1, joinInputs.size());
|
||||
|
||||
List<Slot> leftOutput = left.getOutput();
|
||||
Optional<Plan> rightOpt = candidate.stream().filter(right -> {
|
||||
List<Slot> rightOutput = right.getOutput();
|
||||
|
||||
Set<Slot> joinOutput = getJoinOutput(left, right);
|
||||
Optional<Expression> joinCond = conjuncts.stream()
|
||||
.filter(expr -> {
|
||||
Set<Slot> exprInputSlots = expr.getInputSlots();
|
||||
if (exprInputSlots.isEmpty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (new HashSet<>(leftOutput).containsAll(exprInputSlots)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (new HashSet<>(rightOutput).containsAll(exprInputSlots)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return joinOutput.containsAll(exprInputSlots);
|
||||
}).findFirst();
|
||||
return joinCond.isPresent();
|
||||
}).findFirst();
|
||||
|
||||
Plan right = rightOpt.orElseGet(() -> candidate.get(1));
|
||||
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
|
||||
joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
|
||||
joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
|
||||
conjuncts);
|
||||
List<Expression> joinConditions = pair.first;
|
||||
List<Expression> nonJoinConditions = pair.second;
|
||||
LogicalJoin join = new LogicalJoin<>(JoinType.INNER_JOIN, ExpressionUtils.EMPTY_CONDITION,
|
||||
joinConditions, left, right);
|
||||
|
||||
List<Plan> newInputs = new ArrayList<>();
|
||||
newInputs.add(join);
|
||||
newInputs.addAll(candidate.stream().filter(plan -> !right.equals(plan)).collect(Collectors.toList()));
|
||||
return reorderJoinsAccordingToConditions(newInputs, nonJoinConditions);
|
||||
public JoinType getJoinType() {
|
||||
return joinType;
|
||||
}
|
||||
|
||||
private Map<Boolean, List<Expression>> splitConjuncts(List<Expression> conjuncts, Set<Slot> slots) {
|
||||
return conjuncts.stream().collect(Collectors.partitioningBy(
|
||||
// TODO: support non equal to conditions.
|
||||
expr -> expr instanceof EqualTo && slots.containsAll(expr.getInputSlots())));
|
||||
public List<Expression> getJoinFilter() {
|
||||
return joinFilter;
|
||||
}
|
||||
|
||||
private Set<Slot> getJoinOutput(Plan left, Plan right) {
|
||||
HashSet<Slot> joinOutput = new HashSet<>();
|
||||
joinOutput.addAll(left.getOutput());
|
||||
joinOutput.addAll(right.getOutput());
|
||||
return joinOutput;
|
||||
public List<Expression> getNotInnerJoinConditions() {
|
||||
return notInnerJoinConditions;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visit(Plan plan, Void context) {
|
||||
for (Plan child : plan.children()) {
|
||||
child.accept(this, context);
|
||||
}
|
||||
return null;
|
||||
public MultiJoin withChildren(List<Plan> children) {
|
||||
return new MultiJoin(children, joinFilter, joinType, notInnerJoinConditions);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
|
||||
Plan child = filter.child();
|
||||
if (child instanceof LogicalJoin) {
|
||||
conjunctsForAllHashJoins.addAll(ExpressionUtils.extractConjunction(filter.getPredicates()));
|
||||
public List<Slot> computeOutput() {
|
||||
Builder<Slot> builder = ImmutableList.builder();
|
||||
|
||||
if (joinType.isInnerOrCrossJoin()) {
|
||||
// INNER/CROSS
|
||||
for (Plan child : children) {
|
||||
builder.addAll(child.getOutput());
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
child.accept(this, context);
|
||||
return null;
|
||||
// FULL OUTER JOIN
|
||||
if (joinType.isFullOuterJoin()) {
|
||||
for (Plan child : children) {
|
||||
builder.addAll(child.getOutput().stream()
|
||||
.map(o -> o.withNullable(true))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
// RIGHT OUTER | RIGHT_SEMI/ANTI
|
||||
if (joinType.isRightJoin()) {
|
||||
// RIGHT OUTER
|
||||
if (joinType.isRightOuterJoin()) {
|
||||
builder.addAll(children.get(0).getOutput().stream()
|
||||
.map(o -> o.withNullable(true))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
for (int i = 1; i < children.size(); i++) {
|
||||
builder.addAll(children.get(i).getOutput());
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
// LEFT OUTER | LEFT_SEMI/ANTI
|
||||
if (joinType.isLeftJoin()) {
|
||||
for (int i = 0; i < children.size() - 1; i++) {
|
||||
builder.addAll(children.get(i).getOutput());
|
||||
}
|
||||
// LEFT OUTER
|
||||
if (joinType.isLeftOuterJoin()) {
|
||||
builder.addAll(children.get(arity() - 1).getOutput().stream()
|
||||
.map(o -> o.withNullable(true))
|
||||
.collect(Collectors.toList()));
|
||||
}
|
||||
|
||||
return builder.build();
|
||||
}
|
||||
|
||||
throw new RuntimeException("unreachable");
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
|
||||
if (join.getJoinType() != JoinType.CROSS_JOIN && join.getJoinType() != JoinType.INNER_JOIN) {
|
||||
return null;
|
||||
}
|
||||
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
|
||||
throw new RuntimeException("multiJoin can't invoke accept");
|
||||
}
|
||||
|
||||
join.left().accept(this, context);
|
||||
join.right().accept(this, context);
|
||||
@Override
|
||||
public List<? extends Expression> getExpressions() {
|
||||
return new Builder<Expression>()
|
||||
.addAll(joinFilter)
|
||||
.addAll(notInnerJoinConditions)
|
||||
.build();
|
||||
}
|
||||
|
||||
conjunctsForAllHashJoins.addAll(join.getHashJoinConjuncts());
|
||||
conjunctsForAllHashJoins.addAll(join.getOtherJoinConjuncts());
|
||||
@Override
|
||||
public Plan withGroupExpression(Optional<GroupExpression> groupExpression) {
|
||||
throw new RuntimeException("multiJoin can't invoke withGroupExpression");
|
||||
}
|
||||
|
||||
Plan leftChild = join.left();
|
||||
if (join.left() instanceof LogicalFilter) {
|
||||
leftChild = join.left().child(0);
|
||||
}
|
||||
if (leftChild instanceof GroupPlan) {
|
||||
joinInputs.add(join.left());
|
||||
}
|
||||
Plan rightChild = join.right();
|
||||
if (join.right() instanceof LogicalFilter) {
|
||||
rightChild = join.right().child(0);
|
||||
}
|
||||
if (rightChild instanceof GroupPlan) {
|
||||
joinInputs.add(join.right());
|
||||
}
|
||||
return null;
|
||||
@Override
|
||||
public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) {
|
||||
throw new RuntimeException("multiJoin can't invoke withLogicalProperties");
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Utils.toSqlString("MultiJoin",
|
||||
"joinType", joinType,
|
||||
"joinFilter", joinFilter,
|
||||
"notInnerJoinConditions", notInnerJoinConditions
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -17,12 +17,31 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
import org.apache.doris.nereids.util.PlanUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableList.Builder;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* Try to eliminate cross join via finding join conditions in filters and change the join orders.
|
||||
@ -37,21 +56,308 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
* SELECT * FROM t1 JOIN t3 ON t1.id=t3.id JOIN t2 ON t2.id=t3.id
|
||||
* </pre>
|
||||
* </p>
|
||||
* TODO: This is tested by SSB queries currently, add more `unit` test for this rule
|
||||
* when we have a plan building and comparing framework.
|
||||
* Using the {@link MultiJoin} to complete this task.
|
||||
* {Join cluster}: contain multiple join with filter inside.
|
||||
* <ul>
|
||||
* <li> {Join cluster} to MultiJoin</li>
|
||||
* <li> MultiJoin to {Join cluster}</li>
|
||||
* </ul>
|
||||
*/
|
||||
public class ReorderJoin extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalFilter(subTree(LogicalJoin.class, LogicalFilter.class)).thenApply(ctx -> {
|
||||
LogicalFilter<Plan> filter = ctx.root;
|
||||
if (!ctx.cascadesContext.getConnectContext().getSessionVariable()
|
||||
.isEnableNereidsReorderToEliminateCrossJoin()) {
|
||||
return filter;
|
||||
}
|
||||
MultiJoin multiJoin = new MultiJoin();
|
||||
filter.accept(multiJoin, null);
|
||||
return multiJoin.reorderJoinsAccordingToConditions().orElse(filter);
|
||||
|
||||
Plan plan = joinToMultiJoin(filter);
|
||||
Preconditions.checkState(plan instanceof MultiJoin);
|
||||
MultiJoin multiJoin = (MultiJoin) plan;
|
||||
Plan after = multiJoinToJoin(multiJoin);
|
||||
return after;
|
||||
}).toRule(RuleType.REORDER_JOIN);
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively convert to
|
||||
* {@link LogicalJoin} or
|
||||
* {@link LogicalFilter}--{@link LogicalJoin}
|
||||
* --> {@link MultiJoin}
|
||||
*/
|
||||
public Plan joinToMultiJoin(Plan plan) {
|
||||
// subtree can't specify the end of Pattern. so end can be GroupPlan or Filter
|
||||
if (plan instanceof GroupPlan
|
||||
|| (plan instanceof LogicalFilter && plan.child(0) instanceof GroupPlan)) {
|
||||
return plan;
|
||||
}
|
||||
|
||||
List<Plan> inputs = Lists.newArrayList();
|
||||
List<Expression> joinFilter = Lists.newArrayList();
|
||||
List<Expression> notInnerJoinConditions = Lists.newArrayList();
|
||||
|
||||
LogicalJoin<?, ?> join;
|
||||
// Implicit rely on {rule: MergeFilters}, so don't exist filter--filter--join.
|
||||
if (plan instanceof LogicalFilter) {
|
||||
LogicalFilter<?> filter = (LogicalFilter<?>) plan;
|
||||
joinFilter.addAll(ExpressionUtils.extractConjunction(filter.getPredicates()));
|
||||
join = (LogicalJoin<?, ?>) filter.child();
|
||||
} else {
|
||||
join = (LogicalJoin<?, ?>) plan;
|
||||
}
|
||||
|
||||
if (join.getJoinType().isInnerOrCrossJoin()) {
|
||||
joinFilter.addAll(join.getHashJoinConjuncts());
|
||||
joinFilter.addAll(join.getOtherJoinConjuncts());
|
||||
} else {
|
||||
notInnerJoinConditions.addAll(join.getHashJoinConjuncts());
|
||||
notInnerJoinConditions.addAll(join.getOtherJoinConjuncts());
|
||||
}
|
||||
|
||||
// recursively convert children.
|
||||
Plan left = joinToMultiJoin(join.left());
|
||||
Plan right = joinToMultiJoin(join.right());
|
||||
|
||||
boolean changeLeft = join.getJoinType().isRightJoin()
|
||||
|| join.getJoinType().isFullOuterJoin();
|
||||
if (canCombine(left, changeLeft)) {
|
||||
MultiJoin leftMultiJoin = (MultiJoin) left;
|
||||
inputs.addAll(leftMultiJoin.children());
|
||||
joinFilter.addAll(leftMultiJoin.getJoinFilter());
|
||||
} else {
|
||||
inputs.add(left);
|
||||
}
|
||||
|
||||
boolean changeRight = join.getJoinType().isLeftJoin()
|
||||
|| join.getJoinType().isFullOuterJoin();
|
||||
if (canCombine(right, changeRight)) {
|
||||
MultiJoin rightMultiJoin = (MultiJoin) right;
|
||||
inputs.addAll(rightMultiJoin.children());
|
||||
joinFilter.addAll(rightMultiJoin.getJoinFilter());
|
||||
} else {
|
||||
inputs.add(right);
|
||||
}
|
||||
|
||||
return new MultiJoin(
|
||||
inputs,
|
||||
joinFilter,
|
||||
join.getJoinType(),
|
||||
notInnerJoinConditions);
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively convert to
|
||||
* {@link MultiJoin}
|
||||
* -->
|
||||
* {@link LogicalJoin} or
|
||||
* {@link LogicalFilter}--{@link LogicalJoin}
|
||||
* <p>
|
||||
* When all input is CROSS/Inner Join, all join will be flattened.
|
||||
* Otherwise, we will split {join cluster} into multiple {@link MultiJoin}.
|
||||
* <p>
|
||||
* Here are examples of the {@link MultiJoin}s constructed after this rule has been applied.
|
||||
* <p>
|
||||
* simple example:
|
||||
* <ul>
|
||||
* <li>A JOIN B --> MJ(A, B)
|
||||
* <li>A JOIN B JOIN C JOIN D --> MJ(A, B, C, D)
|
||||
* <li>A LEFT (OUTER/SEMI/ANTI) JOIN B --> MJ([LOJ/LSJ/LAJ]A, B)
|
||||
* <li>A RIGHT (OUTER/SEMI/ANTI) JOIN B --> MJ([ROJ/RSJ/RAJ]A, B)
|
||||
* <li>A FULL JOIN B --> MJ[FOJ](A, B)
|
||||
* </ul>
|
||||
* </p>
|
||||
* <p>
|
||||
* complex example:
|
||||
* <ul>
|
||||
* <li>A LEFT OUTER JOIN (B JOIN C) --> MJ([LOJ]A, MJ(B, C))
|
||||
* <li>(A JOIN B) LEFT JOIN C --> MJ(A, B, C)
|
||||
* <li>(A LEFT OUTER JOIN B) JOIN C --> MJ(MJ(A, B), C)
|
||||
* <li>A LEFT JOIN (B FULL JOIN C) --> MJ(A, MJ[full](B, C))
|
||||
* <li>(A LEFT JOIN B) FULL JOIN (C RIGHT JOIN D) --> MJ[full](MJ(A, B), MJ(C, D))
|
||||
* </ul>
|
||||
* </p>
|
||||
* more complex example:
|
||||
* <ul>
|
||||
* <li> A JOIN B JOIN C LEFT JOIN D --> MJ([LOJ]A, B, C, D)
|
||||
* <li> A JOIN B JOIN C LEFT JOIN (D JOIN F) --> MJ([LOJ]A, B, C, MJ(D, F))
|
||||
* <li> A RIGHT JOIN (B JOIN C JOIN D)--> MJ([ROJ]A, B, C, D)
|
||||
* <li> A JOIN B RIGHT JOIN (C JOIN D) --> MJ(A, B, MJ([ROJ]C, D))
|
||||
* </ul>
|
||||
* </p>
|
||||
* <p>
|
||||
* Graphic presentation:
|
||||
* A JOIN B JOIN C LEFT JOIN D JOIN F
|
||||
* left left│
|
||||
* A B C D F ──► A B C │ D F ──► MJ(LOJ A,B,C,MJ(DF)
|
||||
* <p>
|
||||
* A JOIN B RIGHT JOIN C JOIN D JOIN F
|
||||
* right │right
|
||||
* A B C D F ──► A B │ C D F ──► MJ(A,B,MJ(ROJ C,D,F)
|
||||
* <p>
|
||||
* (A JOIN B JOIN C) FULL JOIN (D JOIN F)
|
||||
* full │
|
||||
* A B C D F ──► A B C │ D F ──► MJ(FOJ MJ(A,B,C) MJ(D,F))
|
||||
* </p>
|
||||
*/
|
||||
public Plan multiJoinToJoin(MultiJoin multiJoin) {
|
||||
if (multiJoin.arity() == 1) {
|
||||
return PlanUtils.filterOrSelf(multiJoin.getJoinFilter(), multiJoin.child(0));
|
||||
}
|
||||
|
||||
Builder<Plan> builder = ImmutableList.builder();
|
||||
// recursively hanlde multiJoin children.
|
||||
for (Plan child : multiJoin.children()) {
|
||||
if (child instanceof MultiJoin) {
|
||||
MultiJoin childMultiJoin = (MultiJoin) child;
|
||||
builder.add(multiJoinToJoin(childMultiJoin));
|
||||
} else {
|
||||
builder.add(child);
|
||||
}
|
||||
}
|
||||
MultiJoin multiJoinHandleChildren = multiJoin.withChildren(builder.build());
|
||||
|
||||
if (!multiJoinHandleChildren.getJoinType().isInnerOrCrossJoin()) {
|
||||
List<Expression> leftFilter = Lists.newArrayList();
|
||||
List<Expression> rightFilter = Lists.newArrayList();
|
||||
List<Expression> remainingFilter = Lists.newArrayList();
|
||||
Plan left;
|
||||
Plan right;
|
||||
if (multiJoinHandleChildren.getJoinType().isLeftJoin()) {
|
||||
left = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
|
||||
leftFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1);
|
||||
} else if (multiJoinHandleChildren.getJoinType().isRightJoin()) {
|
||||
left = multiJoinHandleChildren.child(0);
|
||||
right = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
|
||||
rightFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
} else {
|
||||
left = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1),
|
||||
leftFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
right = multiJoinToJoin(new MultiJoin(
|
||||
multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()),
|
||||
rightFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
}
|
||||
|
||||
// split filter
|
||||
for (Expression expr : multiJoinHandleChildren.getJoinFilter()) {
|
||||
Set<Slot> exprInputSlots = expr.getInputSlots();
|
||||
Preconditions.checkState(!exprInputSlots.isEmpty());
|
||||
|
||||
if (left.getOutputSet().containsAll(exprInputSlots)) {
|
||||
leftFilter.add(expr);
|
||||
} else if (right.getOutputSet().containsAll(exprInputSlots)) {
|
||||
rightFilter.add(expr);
|
||||
} else if (multiJoin.getOutputSet().containsAll(exprInputSlots)) {
|
||||
remainingFilter.add(expr);
|
||||
} else {
|
||||
NereidsPlanner.LOG.error("invalid expression, exist slot that multiJoin don't contains.");
|
||||
throw new RuntimeException("invalid expression, exist slot that multiJoin don't contains.");
|
||||
}
|
||||
}
|
||||
|
||||
return PlanUtils.filterOrSelf(remainingFilter, new LogicalJoin<>(
|
||||
multiJoinHandleChildren.getJoinType(),
|
||||
ExpressionUtils.EMPTY_CONDITION,
|
||||
multiJoinHandleChildren.getNotInnerJoinConditions(),
|
||||
PlanUtils.filterOrSelf(leftFilter, left), PlanUtils.filterOrSelf(rightFilter, right)));
|
||||
}
|
||||
|
||||
// following this multiJoin just contain INNER/CROSS.
|
||||
List<Expression> joinFilter = multiJoinHandleChildren.getJoinFilter();
|
||||
|
||||
Plan left = multiJoinHandleChildren.child(0);
|
||||
List<Plan> candidates = multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity());
|
||||
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join = findInnerJoin(left, candidates, joinFilter);
|
||||
List<Plan> newInputs = Lists.newArrayList();
|
||||
newInputs.add(join);
|
||||
newInputs.addAll(candidates.stream().filter(plan -> !join.right().equals(plan)).collect(Collectors.toList()));
|
||||
|
||||
joinFilter.removeAll(join.getHashJoinConjuncts());
|
||||
joinFilter.removeAll(join.getOtherJoinConjuncts());
|
||||
// TODO(wj): eliminate this recursion.
|
||||
return multiJoinToJoin(new MultiJoin(
|
||||
newInputs,
|
||||
joinFilter,
|
||||
JoinType.INNER_JOIN,
|
||||
ExpressionUtils.EMPTY_CONDITION));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns whether an input can be merged without changing semantics.
|
||||
*
|
||||
* @param input input into a MultiJoin or (GroupPlan|LogicalFilter)
|
||||
* @param changeChildren generate nullable or one side child not exist.
|
||||
* @return true if the input can be combined into a parent MultiJoin
|
||||
*/
|
||||
private static boolean canCombine(Plan input, boolean changeChildren) {
|
||||
return input instanceof MultiJoin
|
||||
&& ((MultiJoin) input).getJoinType().isInnerOrCrossJoin()
|
||||
&& !changeChildren;
|
||||
}
|
||||
|
||||
private Set<Slot> getJoinOutput(Plan left, Plan right) {
|
||||
HashSet<Slot> joinOutput = new HashSet<>();
|
||||
joinOutput.addAll(left.getOutput());
|
||||
joinOutput.addAll(right.getOutput());
|
||||
return joinOutput;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find hash condition from joinFilter
|
||||
* Get InnerJoin from left, right from [candidates].
|
||||
*
|
||||
* @return InnerJoin or CrossJoin{left, last of [candidates]}
|
||||
*/
|
||||
private LogicalJoin<? extends Plan, ? extends Plan> findInnerJoin(Plan left, List<Plan> candidates,
|
||||
List<Expression> joinFilter) {
|
||||
Set<Slot> leftOutputSet = left.getOutputSet();
|
||||
for (int i = 0; i < candidates.size(); i++) {
|
||||
Plan candidate = candidates.get(i);
|
||||
Set<Slot> rightOutputSet = candidate.getOutputSet();
|
||||
|
||||
Set<Slot> joinOutput = getJoinOutput(left, candidate);
|
||||
|
||||
List<Expression> currentJoinFilter = joinFilter.stream()
|
||||
.filter(expr -> {
|
||||
Set<Slot> exprInputSlots = expr.getInputSlots();
|
||||
Preconditions.checkState(exprInputSlots.size() > 1,
|
||||
"Predicate like table.col > 1 must have pushdown.");
|
||||
if (leftOutputSet.containsAll(exprInputSlots)) {
|
||||
return false;
|
||||
}
|
||||
if (rightOutputSet.containsAll(exprInputSlots)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return joinOutput.containsAll(exprInputSlots);
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
|
||||
left.getOutput(), candidate.getOutput(), currentJoinFilter);
|
||||
List<Expression> hashJoinConditions = pair.first;
|
||||
List<Expression> otherJoinConditions = pair.second;
|
||||
if (!hashJoinConditions.isEmpty()) {
|
||||
return new LogicalJoin<>(JoinType.INNER_JOIN,
|
||||
hashJoinConditions, otherJoinConditions,
|
||||
left, candidate);
|
||||
}
|
||||
|
||||
if (i == candidates.size() - 1) {
|
||||
return new LogicalJoin<>(JoinType.CROSS_JOIN,
|
||||
hashJoinConditions, otherJoinConditions,
|
||||
left, candidate);
|
||||
}
|
||||
}
|
||||
throw new RuntimeException("findInnerJoin: can't reach here");
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,6 +92,14 @@ public enum JoinType {
|
||||
return this == INNER_JOIN;
|
||||
}
|
||||
|
||||
public final boolean isInnerOrCrossJoin() {
|
||||
return this == INNER_JOIN || this == CROSS_JOIN;
|
||||
}
|
||||
|
||||
public final boolean isLeftJoin() {
|
||||
return this == LEFT_OUTER_JOIN || this == LEFT_ANTI_JOIN || this == LEFT_SEMI_JOIN;
|
||||
}
|
||||
|
||||
public final boolean isRightJoin() {
|
||||
return this == RIGHT_OUTER_JOIN || this == RIGHT_ANTI_JOIN || this == RIGHT_SEMI_JOIN;
|
||||
}
|
||||
@ -108,6 +116,14 @@ public enum JoinType {
|
||||
return this == RIGHT_OUTER_JOIN;
|
||||
}
|
||||
|
||||
public final boolean isLeftSemiOrAntiJoin() {
|
||||
return this == LEFT_SEMI_JOIN || this == LEFT_ANTI_JOIN;
|
||||
}
|
||||
|
||||
public final boolean isRightSemiOrAntiJoin() {
|
||||
return this == RIGHT_SEMI_JOIN || this == RIGHT_ANTI_JOIN;
|
||||
}
|
||||
|
||||
public final boolean isSemiOrAntiJoin() {
|
||||
return this == LEFT_SEMI_JOIN || this == RIGHT_SEMI_JOIN || this == LEFT_ANTI_JOIN || this == RIGHT_ANTI_JOIN;
|
||||
}
|
||||
|
||||
@ -42,6 +42,7 @@ public enum PlanType {
|
||||
LOGICAL_SELECT_HINT,
|
||||
LOGICAL_ASSERT_NUM_ROWS,
|
||||
LOGICAL_HAVING,
|
||||
LOGICAL_MULTI_JOIN,
|
||||
GROUP_PLAN,
|
||||
|
||||
// physical plan
|
||||
|
||||
@ -102,7 +102,7 @@ public class LogicalFilter<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(PlanVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitLogicalFilter((LogicalFilter<Plan>) this, context);
|
||||
return visitor.visitLogicalFilter(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -17,26 +17,15 @@
|
||||
|
||||
package org.apache.doris.nereids.datasets.ssb;
|
||||
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
|
||||
import org.apache.doris.nereids.util.PlanRewriter;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class SSBJoinReorderTest extends SSBTestBase {
|
||||
public class SSBJoinReorderTest extends SSBTestBase implements PatternMatchSupported {
|
||||
@Test
|
||||
public void q4_1() {
|
||||
test(
|
||||
@ -47,9 +36,11 @@ public class SSBJoinReorderTest extends SSBTestBase {
|
||||
"(lo_suppkey = s_suppkey)",
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of("(((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
|
||||
+ "= CAST('AMERICA' AS STRING))) AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) "
|
||||
+ "OR (CAST(p_mfgr AS STRING) = CAST('MFGR#2' AS STRING))))")
|
||||
ImmutableList.of(
|
||||
"(CAST(c_region AS STRING) = 'AMERICA')",
|
||||
"(CAST(s_region AS STRING) = 'AMERICA')",
|
||||
"((CAST(p_mfgr AS STRING) = 'MFGR#1') OR (CAST(p_mfgr AS STRING) = 'MFGR#2'))"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@ -64,10 +55,11 @@ public class SSBJoinReorderTest extends SSBTestBase {
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of(
|
||||
"((((CAST(c_region AS STRING) = CAST('AMERICA' AS STRING)) AND (CAST(s_region AS STRING) "
|
||||
+ "= CAST('AMERICA' AS STRING))) AND ((d_year = 1997) OR (d_year = 1998))) "
|
||||
+ "AND ((CAST(p_mfgr AS STRING) = CAST('MFGR#1' AS STRING)) OR (CAST(p_mfgr AS STRING) "
|
||||
+ "= CAST('MFGR#2' AS STRING))))")
|
||||
"((d_year = 1997) OR (d_year = 1998))",
|
||||
"(CAST(c_region AS STRING) = 'AMERICA')",
|
||||
"(CAST(s_region AS STRING) = 'AMERICA')",
|
||||
"((CAST(p_mfgr AS STRING) = 'MFGR#1') OR (CAST(p_mfgr AS STRING) = 'MFGR#2'))"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@ -81,94 +73,31 @@ public class SSBJoinReorderTest extends SSBTestBase {
|
||||
"(lo_suppkey = s_suppkey)",
|
||||
"(lo_partkey = p_partkey)"
|
||||
),
|
||||
ImmutableList.of("(((CAST(s_nation AS STRING) = CAST('UNITED STATES' AS STRING)) AND ((d_year = 1997) "
|
||||
+ "OR (d_year = 1998))) AND (CAST(p_category AS STRING) = CAST('MFGR#14' AS STRING)))")
|
||||
ImmutableList.of(
|
||||
"((d_year = 1997) OR (d_year = 1998))",
|
||||
"(CAST(s_nation AS STRING) = 'UNITED STATES')",
|
||||
"(CAST(p_category AS STRING) = 'MFGR#14')"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private void test(String sql, List<String> expectJoinConditions, List<String> expectFilterPredicates) {
|
||||
LogicalPlan analyzed = analyze(sql);
|
||||
LogicalPlan plan = testJoinReorder(analyzed);
|
||||
System.out.println(plan.treeString());
|
||||
new PlanChecker(expectJoinConditions, expectFilterPredicates).check(plan);
|
||||
}
|
||||
PlanChecker planChecker = PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.rewrite()
|
||||
.printlnTree();
|
||||
|
||||
private LogicalPlan testJoinReorder(LogicalPlan plan) {
|
||||
return (LogicalPlan) PlanRewriter.topDownRewrite(plan, connectContext, new ReorderJoin());
|
||||
}
|
||||
|
||||
private static class PlanChecker extends PlanVisitor<Void, Context> {
|
||||
private final List<LogicalRelation> joinInputs = new ArrayList<>();
|
||||
private final List<LogicalJoin> joins = new ArrayList<>();
|
||||
private final List<LogicalFilter> filters = new ArrayList<>();
|
||||
// TODO: it's tricky to compare expression by string, use a graceful manner to do this in the future.
|
||||
private final List<String> expectJoinConditions;
|
||||
private final List<String> expectFilterPredicates;
|
||||
|
||||
public PlanChecker(List<String> expectJoinConditions, List<String> expectFilterPredicates) {
|
||||
this.expectJoinConditions = expectJoinConditions;
|
||||
this.expectFilterPredicates = expectFilterPredicates;
|
||||
for (String expectJoinCondition : expectJoinConditions) {
|
||||
planChecker.matches(
|
||||
innerLogicalJoin().when(
|
||||
join -> join.getHashJoinConjuncts().get(0).toSql().equals(expectJoinCondition))
|
||||
);
|
||||
}
|
||||
|
||||
public void check(Plan plan) {
|
||||
plan.accept(this, new Context(null));
|
||||
|
||||
// check join table orders
|
||||
Assertions.assertEquals(
|
||||
ImmutableList.of("dates", "lineorder", "customer", "supplier", "part"),
|
||||
joinInputs.stream().map(p -> p.getTable().getName()).collect(Collectors.toList()));
|
||||
|
||||
// check join conditions
|
||||
List<String> actualJoinConditions = joins.stream().map(j -> {
|
||||
Optional<Expression> condition = j.getOnClauseCondition();
|
||||
return condition.map(Expression::toSql).orElse("");
|
||||
}).collect(Collectors.toList());
|
||||
|
||||
Assertions.assertEquals(expectJoinConditions, actualJoinConditions);
|
||||
|
||||
// check filter predicates
|
||||
List<String> actualFilterPredicates = filters.stream()
|
||||
.map(f -> f.getPredicates().toSql()).collect(Collectors.toList());
|
||||
Assertions.assertEquals(expectFilterPredicates, actualFilterPredicates);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visit(Plan plan, Context context) {
|
||||
for (Plan child : plan.children()) {
|
||||
child.accept(this, new Context(plan));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visitLogicalRelation(LogicalRelation relation, Context context) {
|
||||
if (context.parent instanceof LogicalJoin) {
|
||||
joinInputs.add(relation);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visitLogicalFilter(LogicalFilter<? extends Plan> filter, Context context) {
|
||||
filters.add(filter);
|
||||
filter.child().accept(this, new Context(filter));
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Void visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Context context) {
|
||||
join.left().accept(this, new Context(join));
|
||||
join.right().accept(this, new Context(join));
|
||||
joins.add(join);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
private static class Context {
|
||||
public final Plan parent;
|
||||
|
||||
public Context(Plan parent) {
|
||||
this.parent = parent;
|
||||
for (String expectFilterPredicate : expectFilterPredicates) {
|
||||
planChecker.matches(
|
||||
logicalFilter().when(filter -> filter.getPredicates().toSql().equals(expectFilterPredicate))
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,48 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.nereids.analyzer.UnboundRelation;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class JoinReorderTest implements PatternMatchSupported {
|
||||
|
||||
/**
|
||||
* To test do not throw unexpected exception when join type is not inner or cross.
|
||||
*/
|
||||
@Test
|
||||
public void testWithOuterJoin() {
|
||||
UnboundRelation relation1 = new UnboundRelation(Lists.newArrayList("db", "table1"));
|
||||
UnboundRelation relation2 = new UnboundRelation(Lists.newArrayList("db", "table2"));
|
||||
LogicalJoin outerJoin = new LogicalJoin<>(JoinType.LEFT_OUTER_JOIN, relation1, relation2);
|
||||
LogicalFilter logicalFilter = new LogicalFilter<>(BooleanLiteral.FALSE, outerJoin);
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), logicalFilter)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.matches(logicalFilter(leftOuterLogicalJoin(unboundRelation(), unboundRelation())));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,183 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
class ReorderJoinTest implements PatternMatchSupported {
|
||||
|
||||
private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
private final LogicalOlapScan scan4 = PlanConstructor.newLogicalOlapScan(3, "t4", 0);
|
||||
|
||||
@Test
|
||||
public void testLeftOuterJoin() {
|
||||
ImmutableList<LogicalPlan> plans = ImmutableList.of(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build(),
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build()
|
||||
);
|
||||
|
||||
check(plans);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRightOuterJoin() {
|
||||
ImmutableList<LogicalPlan> plans = ImmutableList.of(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build(),
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build()
|
||||
);
|
||||
|
||||
check(plans);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLeftSemiJoin() {
|
||||
ImmutableList<LogicalPlan> plans = ImmutableList.of(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build(),
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.hashJoinUsing(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build()
|
||||
);
|
||||
|
||||
check(plans);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRightSemiJoin() {
|
||||
ImmutableList<LogicalPlan> plans = ImmutableList.of(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0)))
|
||||
.build(),
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build()
|
||||
);
|
||||
|
||||
check(plans);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testFullOuterJoin() {
|
||||
ImmutableList<LogicalPlan> plans = ImmutableList.of(
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(0, 0))
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build(),
|
||||
new LogicalPlanBuilder(scan1)
|
||||
.hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN)
|
||||
.hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(0, 0))
|
||||
.filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0)))
|
||||
.build()
|
||||
);
|
||||
|
||||
check(plans);
|
||||
}
|
||||
|
||||
public void check(List<LogicalPlan> plans) {
|
||||
for (LogicalPlan plan : plans) {
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.matchesFromRoot(
|
||||
logicalJoin(
|
||||
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
|
||||
leafPlan()
|
||||
).whenNot(join -> join.getJoinType().isCrossJoin())
|
||||
)
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* join
|
||||
* crossjoin / \
|
||||
* / \ join D
|
||||
* innerjoin innerjoin ──► / \
|
||||
* / \ / \ join C
|
||||
* A B C D / \
|
||||
* A B
|
||||
*/
|
||||
@Test
|
||||
public void testInnerOrCrossJoin() {
|
||||
LogicalPlan leftJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.build();
|
||||
LogicalPlan rightJoin = new LogicalPlanBuilder(scan3)
|
||||
.hashJoinUsing(scan4, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.build();
|
||||
|
||||
LogicalPlan plan = new LogicalPlanBuilder(leftJoin)
|
||||
.hashJoinEmptyOn(rightJoin, JoinType.CROSS_JOIN)
|
||||
.filter(new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0)))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.matchesFromRoot(
|
||||
logicalJoin(
|
||||
logicalJoin(
|
||||
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
|
||||
leafPlan()
|
||||
).whenNot(join -> join.getJoinType().isCrossJoin()),
|
||||
leafPlan()
|
||||
).whenNot(join -> join.getJoinType().isCrossJoin())
|
||||
)
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
@ -23,9 +23,12 @@ import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SqlTest extends TestWithFeService implements PatternMatchSupported {
|
||||
import java.util.List;
|
||||
|
||||
public class MultiJoinTest extends TestWithFeService implements PatternMatchSupported {
|
||||
@Override
|
||||
protected void runBeforeAll() throws Exception {
|
||||
createDatabase("test");
|
||||
@ -68,17 +71,44 @@ public class SqlTest extends TestWithFeService implements PatternMatchSupported
|
||||
}
|
||||
|
||||
@Test
|
||||
void testSql() {
|
||||
// String sql = "SELECT *"
|
||||
// + " FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id"
|
||||
// + " WHERE T1.id = T2.id";
|
||||
String sql = "SELECT *"
|
||||
+ " FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1"
|
||||
+ " WHERE T1.id = T2.id";
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.applyTopDown(new ReorderJoin())
|
||||
.implement()
|
||||
.printlnTree();
|
||||
void testMultiJoinEliminateCross() {
|
||||
List<String> sqls = ImmutableList.<String>builder()
|
||||
.add("SELECT * FROM T1, T2 LEFT JOIN T3 ON T2.id = T3.id WHERE T1.id = T2.id")
|
||||
.add("SELECT * FROM T2 LEFT JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id = T2.id")
|
||||
.build();
|
||||
|
||||
for (String sql : sqls) {
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.matches(
|
||||
logicalJoin(
|
||||
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
|
||||
leafPlan()
|
||||
).whenNot(join -> join.getJoinType().isCrossJoin())
|
||||
)
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMultiJoinExistCross() {
|
||||
List<String> sqls = ImmutableList.<String>builder()
|
||||
.add("SELECT * FROM T2 LEFT SEMI JOIN T3 ON T2.id = T3.id, T1 WHERE T1.id > T2.id")
|
||||
.build();
|
||||
|
||||
for (String sql : sqls) {
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze(sql)
|
||||
.applyBottomUp(new ReorderJoin())
|
||||
.matches(
|
||||
logicalJoin(
|
||||
logicalJoin().whenNot(join -> join.getJoinType().isCrossJoin()),
|
||||
leafPlan()
|
||||
).when(join -> join.getJoinType().isCrossJoin())
|
||||
.whenNot(join -> join.getOtherJoinConjuncts().isEmpty())
|
||||
)
|
||||
.printlnTree();
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -81,7 +81,6 @@ public class AnalyzeWhereSubqueryTest extends TestWithFeService implements Patte
|
||||
= "select * from t6 where t6.k1 < (select max(aa) from (select v1 as aa from t7 where t6.k2=t7.v2) t2 )";
|
||||
private final List<String> testSql = ImmutableList.of(
|
||||
sql1, sql2, sql3, sql4, sql5, sql6, sql7, sql8, sql9, sql10
|
||||
|
||||
);
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user