[feature](Nereids): normalize join condition after expanding or condition NLJ (#22555)

This commit is contained in:
谢健
2023-08-04 13:37:37 +08:00
committed by GitHub
parent d5a21de796
commit 658d75c816
5 changed files with 303 additions and 53 deletions

View File

@ -102,7 +102,6 @@ public class RuleSet {
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
.add(new MergeProjectsCBO())
.add(new OrExpansion())
.build();
public static final List<Rule> OTHER_REORDER_RULES = planRuleFactories()
@ -117,6 +116,7 @@ public class RuleSet {
.add(PushdownProjectThroughSemiJoin.INSTANCE)
.add(TransposeAggSemiJoin.INSTANCE)
.add(TransposeAggSemiJoinProject.INSTANCE)
.add(OrExpansion.INSTANCE)
.build();
public static final List<RuleFactory> PUSH_DOWN_FILTERS = ImmutableList.of(
@ -204,7 +204,6 @@ public class RuleSet {
.build();
public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
.addAll(EXPLORATION_RULES)
.add(JoinCommute.BUSHY.build())
.build();

View File

@ -21,7 +21,7 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.rules.rewrite.PushdownExpressionsInHashCondition;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -31,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.JoinUtils;
@ -52,6 +53,7 @@ import java.util.stream.Collectors;
* HJ(cond1) HJ(cond2 and !cond1)
*/
public class OrExpansion extends OneExplorationRuleFactory {
public static final OrExpansion INSTANCE = new OrExpansion();
@Override
public Rule build() {
@ -68,10 +70,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
// We pick the first OR condition that can be split into EqualTo expressions.
for (Expression expr : otherConditions) {
Pair<List<Expression>, List<Expression>> pair = expandExpr(expr, join);
// TODO: Now we don't support expand condition with complex expression
if (pair.second.isEmpty() && pair.first.stream()
.noneMatch(e -> !((EqualTo) e).left().isSlot()
&& !((EqualTo) e).right().isSlot())) {
if (pair.second.isEmpty()) {
disjunctions = pair.first;
otherConditions.remove(expr);
break;
@ -136,10 +135,16 @@ public class OrExpansion extends OneExplorationRuleFactory {
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s))
.collect(Collectors.toList());
// TODO: normalize join condition
LogicalJoin<? extends Plan, ? extends Plan> newJoin = join.withJoinConjuncts(hashCond, otherCond)
.withChildren(Lists.newArrayList(left, right));
joins.add(newJoin);
if (newJoin.getHashJoinConjuncts().stream().anyMatch(equalTo ->
equalTo.children().stream().anyMatch(e -> !(e instanceof Slot)))) {
Plan plan = PushdownExpressionsInHashCondition.pushDownHashExpression(newJoin);
plan = new LogicalProject<>(new ArrayList<>(newJoin.getOutput()), plan);
joins.add(plan);
} else {
joins.add(newJoin);
}
}
return joins;
}

View File

@ -64,47 +64,54 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory {
return logicalJoin()
.when(join -> join.getHashJoinConjuncts().stream().anyMatch(equalTo ->
equalTo.children().stream().anyMatch(e -> !(e instanceof Slot))))
.then(join -> {
Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
join.getHashJoinConjuncts().forEach(conjunct -> {
Preconditions.checkArgument(conjunct instanceof EqualTo);
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
// doesn't swap the two sides.
conjunct = JoinUtils.swapEqualToForChildrenOrder(
(EqualTo) conjunct, join.left().getOutputSet());
generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs);
generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs);
});
// add other conjuncts used slots to project exprs
Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
join.getOtherJoinConjuncts().stream().flatMap(conjunct ->
conjunct.getInputSlots().stream()
).forEach(slot -> {
if (leftExprIdSet.contains(slot.getExprId())) {
// belong to left child
leftProjectExprs.add(slot);
} else {
// belong to right child
rightProjectExprs.add(slot);
}
});
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
.map(equalTo -> equalTo.withChildren(equalTo.children()
.stream().map(expr -> exprReplaceMap.get(expr).toSlot())
.collect(ImmutableList.toImmutableList())))
.collect(ImmutableList.toImmutableList());
return join.withHashJoinConjunctsAndChildren(
newHashConjuncts,
createChildProjectPlan(join.left(), join, leftProjectExprs),
createChildProjectPlan(join.right(), join, rightProjectExprs));
}).toRule(RuleType.PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
.then(PushdownExpressionsInHashCondition::pushDownHashExpression)
.toRule(RuleType.PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
}
private LogicalProject createChildProjectPlan(Plan plan, LogicalJoin join,
/**
 * Normalizes the hash conjuncts of a join so that every operand of each equality is a slot.
 * Operands are mapped to named expressions that are projected on the matching join child,
 * and the hash conjuncts are rebuilt on top of the slots of those named expressions.
 *
 * @param join the join whose hash conjuncts are rewritten; every hash conjunct must be an EqualTo
 * @return the join rebuilt with the new hash conjuncts and a project on each child
 */
public static Plan pushDownHashExpression(LogicalJoin<? extends Plan, ? extends Plan> join) {
    Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
    Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
    Set<NamedExpression> rightProjectExprs = Sets.newHashSet();

    for (Expression hashConjunct : join.getHashJoinConjuncts()) {
        Preconditions.checkArgument(hashConjunct instanceof EqualTo);
        // A condition may be written as t2.a + 1 = t1.a + 2. The reordered copy is only used
        // to decide which child each operand is projected on; the conjuncts themselves are
        // rebuilt below from the original (unswapped) hash conjuncts.
        Expression ordered = JoinUtils.swapEqualToForChildrenOrder(
                (EqualTo) hashConjunct, join.left().getOutputSet());
        generateReplaceMapAndProjectExprs(ordered.child(0), exprReplaceMap, leftProjectExprs);
        generateReplaceMapAndProjectExprs(ordered.child(1), exprReplaceMap, rightProjectExprs);
    }

    // Slots referenced by the other (non-hash) conjuncts must also be projected.
    // A slot that is not produced by the left child is attributed to the right child.
    Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
    for (Expression otherConjunct : join.getOtherJoinConjuncts()) {
        for (Slot inputSlot : otherConjunct.getInputSlots()) {
            Set<NamedExpression> target = leftExprIdSet.contains(inputSlot.getExprId())
                    ? leftProjectExprs
                    : rightProjectExprs;
            target.add(inputSlot);
        }
    }

    // Rebuild every hash conjunct with each operand replaced by the slot of its mapped expression.
    ImmutableList.Builder<Expression> rebuiltConjuncts = ImmutableList.builder();
    for (Expression equalTo : join.getHashJoinConjuncts()) {
        ImmutableList.Builder<Expression> slotOperands = ImmutableList.builder();
        for (Expression operand : equalTo.children()) {
            slotOperands.add(exprReplaceMap.get(operand).toSlot());
        }
        rebuiltConjuncts.add(equalTo.withChildren(slotOperands.build()));
    }
    List<Expression> newHashConjuncts = rebuiltConjuncts.build();

    return join.withHashJoinConjunctsAndChildren(
            newHashConjuncts,
            createChildProjectPlan(join.left(), join, leftProjectExprs),
            createChildProjectPlan(join.right(), join, rightProjectExprs));
}
private static LogicalProject createChildProjectPlan(Plan plan, LogicalJoin join,
Set<NamedExpression> conditionUsedExprs) {
Set<NamedExpression> intersectionSlots = Sets.newHashSet(plan.getOutputSet());
intersectionSlots.retainAll(join.getOutputSet());
@ -113,7 +120,7 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory {
.collect(ImmutableList.toImmutableList()), plan);
}
private void generateReplaceMapAndProjectExprs(Expression expr, Map<Expression, NamedExpression> replaceMap,
private static void generateReplaceMapAndProjectExprs(Expression expr, Map<Expression, NamedExpression> replaceMap,
Set<NamedExpression> projects) {
if (expr instanceof SlotReference) {
projects.add((SlotReference) expr);