[feature](Nereids): normalize join condition after expanding or condition NLJ (#22555)
This commit is contained in:
@ -102,7 +102,6 @@ public class RuleSet {
|
||||
|
||||
public static final List<Rule> EXPLORATION_RULES = planRuleFactories()
|
||||
.add(new MergeProjectsCBO())
|
||||
.add(new OrExpansion())
|
||||
.build();
|
||||
|
||||
public static final List<Rule> OTHER_REORDER_RULES = planRuleFactories()
|
||||
@ -117,6 +116,7 @@ public class RuleSet {
|
||||
.add(PushdownProjectThroughSemiJoin.INSTANCE)
|
||||
.add(TransposeAggSemiJoin.INSTANCE)
|
||||
.add(TransposeAggSemiJoinProject.INSTANCE)
|
||||
.add(OrExpansion.INSTANCE)
|
||||
.build();
|
||||
|
||||
public static final List<RuleFactory> PUSH_DOWN_FILTERS = ImmutableList.of(
|
||||
@ -204,7 +204,6 @@ public class RuleSet {
|
||||
.build();
|
||||
|
||||
public static final List<Rule> DPHYP_REORDER_RULES = ImmutableList.<Rule>builder()
|
||||
.addAll(EXPLORATION_RULES)
|
||||
.add(JoinCommute.BUSHY.build())
|
||||
.build();
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.StatementContext;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.rules.rewrite.PushdownExpressionsInHashCondition;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Not;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
@ -31,6 +31,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.JoinUtils;
|
||||
@ -52,6 +53,7 @@ import java.util.stream.Collectors;
|
||||
* HJ(cond1) HJ(cond2 and !cond1)
|
||||
*/
|
||||
public class OrExpansion extends OneExplorationRuleFactory {
|
||||
public static final OrExpansion INSTANCE = new OrExpansion();
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
@ -68,10 +70,7 @@ public class OrExpansion extends OneExplorationRuleFactory {
|
||||
// We pick the first or condition that can be split to EqualTo expressions.
|
||||
for (Expression expr : otherConditions) {
|
||||
Pair<List<Expression>, List<Expression>> pair = expandExpr(expr, join);
|
||||
// TODO: Now we don't support expand condition with complex expression
|
||||
if (pair.second.isEmpty() && pair.first.stream()
|
||||
.noneMatch(e -> !((EqualTo) e).left().isSlot()
|
||||
&& !((EqualTo) e).right().isSlot())) {
|
||||
if (pair.second.isEmpty()) {
|
||||
disjunctions = pair.first;
|
||||
otherConditions.remove(expr);
|
||||
break;
|
||||
@ -136,10 +135,16 @@ public class OrExpansion extends OneExplorationRuleFactory {
|
||||
.map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s))
|
||||
.collect(Collectors.toList());
|
||||
|
||||
// TODO: normalize join condition
|
||||
LogicalJoin<? extends Plan, ? extends Plan> newJoin = join.withJoinConjuncts(hashCond, otherCond)
|
||||
.withChildren(Lists.newArrayList(left, right));
|
||||
joins.add(newJoin);
|
||||
if (newJoin.getHashJoinConjuncts().stream().anyMatch(equalTo ->
|
||||
equalTo.children().stream().anyMatch(e -> !(e instanceof Slot)))) {
|
||||
Plan plan = PushdownExpressionsInHashCondition.pushDownHashExpression(newJoin);
|
||||
plan = new LogicalProject<>(new ArrayList<>(newJoin.getOutput()), plan);
|
||||
joins.add(plan);
|
||||
} else {
|
||||
joins.add(newJoin);
|
||||
}
|
||||
}
|
||||
return joins;
|
||||
}
|
||||
|
||||
@ -64,47 +64,54 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory {
|
||||
return logicalJoin()
|
||||
.when(join -> join.getHashJoinConjuncts().stream().anyMatch(equalTo ->
|
||||
equalTo.children().stream().anyMatch(e -> !(e instanceof Slot))))
|
||||
.then(join -> {
|
||||
Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
|
||||
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
|
||||
Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
|
||||
join.getHashJoinConjuncts().forEach(conjunct -> {
|
||||
Preconditions.checkArgument(conjunct instanceof EqualTo);
|
||||
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
|
||||
// doesn't swap the two sides.
|
||||
conjunct = JoinUtils.swapEqualToForChildrenOrder(
|
||||
(EqualTo) conjunct, join.left().getOutputSet());
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs);
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs);
|
||||
});
|
||||
|
||||
// add other conjuncts used slots to project exprs
|
||||
Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
|
||||
join.getOtherJoinConjuncts().stream().flatMap(conjunct ->
|
||||
conjunct.getInputSlots().stream()
|
||||
).forEach(slot -> {
|
||||
if (leftExprIdSet.contains(slot.getExprId())) {
|
||||
// belong to left child
|
||||
leftProjectExprs.add(slot);
|
||||
} else {
|
||||
// belong to right child
|
||||
rightProjectExprs.add(slot);
|
||||
}
|
||||
});
|
||||
|
||||
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
|
||||
.map(equalTo -> equalTo.withChildren(equalTo.children()
|
||||
.stream().map(expr -> exprReplaceMap.get(expr).toSlot())
|
||||
.collect(ImmutableList.toImmutableList())))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return join.withHashJoinConjunctsAndChildren(
|
||||
newHashConjuncts,
|
||||
createChildProjectPlan(join.left(), join, leftProjectExprs),
|
||||
createChildProjectPlan(join.right(), join, rightProjectExprs));
|
||||
}).toRule(RuleType.PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
|
||||
.then(PushdownExpressionsInHashCondition::pushDownHashExpression)
|
||||
.toRule(RuleType.PUSHDOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
|
||||
}
|
||||
|
||||
private LogicalProject createChildProjectPlan(Plan plan, LogicalJoin join,
|
||||
/**
|
||||
* push down complex expression in hash condition
|
||||
*/
|
||||
public static Plan pushDownHashExpression(LogicalJoin<? extends Plan, ? extends Plan> join) {
|
||||
Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
|
||||
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
|
||||
Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
|
||||
join.getHashJoinConjuncts().forEach(conjunct -> {
|
||||
Preconditions.checkArgument(conjunct instanceof EqualTo);
|
||||
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
|
||||
// doesn't swap the two sides.
|
||||
conjunct = JoinUtils.swapEqualToForChildrenOrder(
|
||||
(EqualTo) conjunct, join.left().getOutputSet());
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs);
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs);
|
||||
});
|
||||
|
||||
// add other conjuncts used slots to project exprs
|
||||
Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
|
||||
join.getOtherJoinConjuncts().stream().flatMap(conjunct ->
|
||||
conjunct.getInputSlots().stream()
|
||||
).forEach(slot -> {
|
||||
if (leftExprIdSet.contains(slot.getExprId())) {
|
||||
// belong to left child
|
||||
leftProjectExprs.add(slot);
|
||||
} else {
|
||||
// belong to right child
|
||||
rightProjectExprs.add(slot);
|
||||
}
|
||||
});
|
||||
|
||||
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
|
||||
.map(equalTo -> equalTo.withChildren(equalTo.children()
|
||||
.stream().map(expr -> exprReplaceMap.get(expr).toSlot())
|
||||
.collect(ImmutableList.toImmutableList())))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return join.withHashJoinConjunctsAndChildren(
|
||||
newHashConjuncts,
|
||||
createChildProjectPlan(join.left(), join, leftProjectExprs),
|
||||
createChildProjectPlan(join.right(), join, rightProjectExprs));
|
||||
|
||||
}
|
||||
|
||||
private static LogicalProject createChildProjectPlan(Plan plan, LogicalJoin join,
|
||||
Set<NamedExpression> conditionUsedExprs) {
|
||||
Set<NamedExpression> intersectionSlots = Sets.newHashSet(plan.getOutputSet());
|
||||
intersectionSlots.retainAll(join.getOutputSet());
|
||||
@ -113,7 +120,7 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory {
|
||||
.collect(ImmutableList.toImmutableList()), plan);
|
||||
}
|
||||
|
||||
private void generateReplaceMapAndProjectExprs(Expression expr, Map<Expression, NamedExpression> replaceMap,
|
||||
private static void generateReplaceMapAndProjectExprs(Expression expr, Map<Expression, NamedExpression> replaceMap,
|
||||
Set<NamedExpression> projects) {
|
||||
if (expr instanceof SlotReference) {
|
||||
projects.add((SlotReference) expr);
|
||||
|
||||
Reference in New Issue
Block a user