branch-2.1: [fix](Nereids) we should also push down expr in join's mark conjuncts #50886 (#50955)

Cherry-picked from #50886

Co-authored-by: morrySnow <zhangwenxin@selectdb.com>
This commit is contained in:
github-actions[bot]
2025-05-15 22:46:54 +08:00
committed by GitHub
parent 01f70deb8b
commit 82d1375dc5
3 changed files with 144 additions and 15 deletions

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
@ -62,8 +63,19 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
/**
 * Builds the rewrite rule: a {@code logicalJoin()} pattern guarded by the predicate(s) below,
 * rewritten by {@code pushDownHashExpression} and registered as
 * {@code RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS}.
 *
 * @return the rule produced by {@code toRule(...)}
 */
@Override
public Rule build() {
return logicalJoin()
// NOTE(review): this listing looks like a rendered diff whose +/- markers were lost; the two-line
// predicate right below is presumably the pre-change guard that the block-bodied predicate after it
// replaced -- confirm against the real file. It holds when any hash join conjunct has a child that
// is not a Slot.
.when(join -> join.getHashJoinConjuncts().stream().anyMatch(equalTo ->
equalTo.children().stream().anyMatch(e -> !(e instanceof Slot))))
.when(join -> {
// true when at least one hash join conjunct has a child that is not a plain Slot
boolean needProcessHashConjuncts = join.getHashJoinConjuncts().stream()
.anyMatch(equalTo -> equalTo.children().stream()
.anyMatch(e -> !(e instanceof Slot)));
List<Slot> leftSlots = join.left().getOutput();
List<Slot> rightSlots = join.right().getOutput();
// split the mark join conjuncts with the children's output slots; only pair.first is inspected here
// (presumably pair.first holds the conjuncts usable as hash conditions -- verify in JoinUtils)
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
leftSlots, rightSlots, join.getMarkJoinConjuncts());
// same non-Slot-child test as above, applied to the first half of the split mark conjuncts
boolean needProcessMarkConjuncts = pair.first.stream()
.anyMatch(equalTo -> equalTo.children().stream()
.anyMatch(e -> !(e instanceof Slot)));
// match the join if either group contains an expression that is not a bare slot
return needProcessHashConjuncts || needProcessMarkConjuncts;
})
.then(PushDownExpressionsInHashCondition::pushDownHashExpression)
.toRule(RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
}
@ -75,15 +87,20 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
LogicalJoin<? extends Plan, ? extends Plan> join) {
Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
Map<Expression, NamedExpression> replaceMap = Maps.newHashMap();
join.getHashJoinConjuncts().forEach(conjunct -> {
Preconditions.checkArgument(conjunct instanceof EqualPredicate);
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
// doesn't swap the two sides.
conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet());
generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs);
generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs);
generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, leftProjectExprs);
generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, rightProjectExprs);
});
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
.map(equalTo -> equalTo.withChildren(equalTo.children()
.stream().map(expr -> replaceMap.get(expr).toSlot())
.collect(ImmutableList.toImmutableList())))
.collect(ImmutableList.toImmutableList());
// add other conjuncts used slots to project exprs
Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
@ -100,7 +117,28 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
});
// add mark conjuncts used slots to project exprs
join.getMarkJoinConjuncts().stream().flatMap(conjunct ->
// if mark conjuncts could be hash condition, normalize it
List<Slot> leftSlots = join.left().getOutput();
List<Slot> rightSlots = join.right().getOutput();
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(leftSlots,
rightSlots, join.getMarkJoinConjuncts());
pair.first.forEach(conjunct -> {
Preconditions.checkArgument(conjunct instanceof EqualPredicate);
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
// doesn't swap the two sides.
conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet());
generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, leftProjectExprs);
generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, rightProjectExprs);
});
ImmutableList.Builder<Expression> newMarkConjunctsBuilder = ImmutableList.builder();
pair.first.stream()
.map(equalTo -> equalTo.withChildren(equalTo.children()
.stream().map(expr -> replaceMap.get(expr).toSlot())
.collect(ImmutableList.toImmutableList())))
.forEach(newMarkConjunctsBuilder::add);
newMarkConjunctsBuilder.addAll(pair.second);
pair.second.stream().flatMap(conjunct ->
conjunct.getInputSlots().stream()
).forEach(slot -> {
if (leftExprIdSet.contains(slot.getExprId())) {
@ -111,14 +149,9 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
rightProjectExprs.add(slot);
}
});
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
.map(equalTo -> equalTo.withChildren(equalTo.children()
.stream().map(expr -> exprReplaceMap.get(expr).toSlot())
.collect(ImmutableList.toImmutableList())))
.collect(ImmutableList.toImmutableList());
return join.withHashJoinConjunctsAndChildren(
return join.withHashAndMarkJoinConjunctsAndChildren(
newHashConjuncts,
newMarkConjunctsBuilder.build(),
createChildProjectPlan(join.left(), join, leftProjectExprs),
createChildProjectPlan(join.right(), join, rightProjectExprs), join.getJoinReorderContext());

View File

@ -385,8 +385,21 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
children, otherJoinReorderContext);
}
public LogicalJoin<Plan, Plan> withHashJoinConjunctsAndChildren(
List<Expression> hashJoinConjuncts, Plan left, Plan right, JoinReorderContext otherJoinReorderContext) {
/**
* Creates a new LogicalJoin with updated hash join conjuncts, mark join conjuncts, and child plans.
*
* @param hashJoinConjuncts the list of hash join conjuncts used for hash-based join conditions.
* @param markJoinConjuncts the list of mark join conjuncts used for marking specific join conditions.
* These are typically used in semi-join or anti-join scenarios to track
* whether a condition is satisfied.
* @param left the left child plan.
* @param right the right child plan.
* @param otherJoinReorderContext the context for join reordering.
* @return a new LogicalJoin instance with the specified parameters.
*/
public LogicalJoin<Plan, Plan> withHashAndMarkJoinConjunctsAndChildren(
List<Expression> hashJoinConjuncts, List<Expression> markJoinConjuncts,
Plan left, Plan right, JoinReorderContext otherJoinReorderContext) {
Preconditions.checkArgument(children.size() == 2);
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),

View File

@ -20,7 +20,26 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Positive;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;
@ -184,4 +203,68 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im
)
);
}
/**
 * Checks that {@code PushDownExpressionsInHashCondition} rewrites equality conjuncts found among a
 * join's mark conjuncts as well as those among its hash conjuncts.
 *
 * <p>The join is built over two one-row relations with:
 * <ul>
 *   <li>hash conjuncts: {@code abs(a) = positive(b)} and its cast-to-string variant;</li>
 *   <li>other conjuncts: {@code a + 1};</li>
 *   <li>mark conjuncts: {@code abs(a) = positive(b)}, its cast-to-bigint variant, and {@code a + 1}.</li>
 * </ul>
 *
 * <p>After applying the rule top-down, each child must be a project with four entries that aliases
 * each of the three pushed-down expressions exactly once, and every {@code EqualTo} left in the
 * hash and mark conjuncts must compare two {@code SlotReference}s.
 */
@Test
public void testPushDownMarkConjuncts() {
    // inputs: left emits alias "a" (ExprId 1) over literal 1, right emits alias "b" (ExprId 2) over literal 2
    Plan leftChild = new LogicalOneRowRelation(new RelationId(1),
            ImmutableList.of(new Alias(new ExprId(1), new IntegerLiteral(1), "a")));
    Plan rightChild = new LogicalOneRowRelation(new RelationId(2),
            ImmutableList.of(new Alias(new ExprId(2), new IntegerLiteral(2), "b")));

    // expression shared by hash and mark conjuncts, plus one cast variant for each of them
    Expression absOfLeft = new Abs(leftChild.getOutput().get(0));
    Expression positiveOfRight = new Positive(rightChild.getOutput().get(0));
    Expression leftAsString = new Cast(absOfLeft, StringType.INSTANCE);
    Expression rightAsString = new Cast(positiveOfRight, StringType.INSTANCE);
    Expression leftAsBigInt = new Cast(absOfLeft, BigIntType.INSTANCE);
    Expression rightAsBigInt = new Cast(positiveOfRight, BigIntType.INSTANCE);

    Expression sharedEqual = new EqualTo(absOfLeft, positiveOfRight);
    Expression hashOnlyEqual = new EqualTo(leftAsString, rightAsString);
    Expression markOnlyEqual = new EqualTo(leftAsBigInt, rightAsBigInt);
    Expression nonEqualConjunct = new Add(leftChild.getOutput().get(0), new IntegerLiteral(1));

    LogicalJoin<?, ?> bareJoin = new LogicalJoin<>(
            JoinType.INNER_JOIN,
            leftChild,
            rightChild,
            new JoinReorderContext()
    );
    LogicalJoin<?, ?> joinUnderTest = bareJoin.withJoinConjuncts(
            ImmutableList.of(sharedEqual, hashOnlyEqual),
            ImmutableList.of(nonEqualConjunct),
            ImmutableList.of(sharedEqual, markOnlyEqual, nonEqualConjunct),
            new JoinReorderContext());

    PlanChecker.from(connectContext, joinUnderTest)
            .applyTopDown(new PushDownExpressionsInHashCondition())
            .matches(
                    logicalJoin(
                            // left project: 4 entries, one alias per pushed-down left-side expression
                            logicalProject(logicalOneRowRelation()).when(proj ->
                                    proj.getProjects().size() == 4
                                    && proj.getProjects().stream()
                                            .filter(ne -> ne instanceof Alias)
                                            .map(ne -> ((Alias) ne).child())
                                            .filter(child -> absOfLeft.equals(child))
                                            .count() == 1
                                    && proj.getProjects().stream()
                                            .filter(ne -> ne instanceof Alias)
                                            .map(ne -> ((Alias) ne).child())
                                            .filter(child -> leftAsBigInt.equals(child))
                                            .count() == 1
                                    && proj.getProjects().stream()
                                            .filter(ne -> ne instanceof Alias)
                                            .map(ne -> ((Alias) ne).child())
                                            .filter(child -> leftAsString.equals(child))
                                            .count() == 1),
                            // right project: 4 entries, one alias per pushed-down right-side expression
                            logicalProject(logicalOneRowRelation()).when(proj ->
                                    proj.getProjects().size() == 4
                                    && proj.getProjects().stream()
                                            .filter(ne -> ne instanceof Alias)
                                            .map(ne -> ((Alias) ne).child())
                                            .filter(child -> positiveOfRight.equals(child))
                                            .count() == 1
                                    && proj.getProjects().stream()
                                            .filter(ne -> ne instanceof Alias)
                                            .map(ne -> ((Alias) ne).child())
                                            .filter(child -> rightAsBigInt.equals(child))
                                            .count() == 1
                                    && proj.getProjects().stream()
                                            .filter(ne -> ne instanceof Alias)
                                            .map(ne -> ((Alias) ne).child())
                                            .filter(child -> rightAsString.equals(child))
                                            .count() == 1)
                    )
                    // mark conjuncts: 3 in total, 2 of them EqualTo, none with a non-slot operand
                    .when(j -> j.getMarkJoinConjuncts().size() == 3
                            && j.getMarkJoinConjuncts().stream()
                                    .filter(c -> c instanceof EqualTo)
                                    .noneMatch(c -> !(((EqualTo) c).left() instanceof SlotReference)
                                            || !(((EqualTo) c).right() instanceof SlotReference))
                            && j.getMarkJoinConjuncts().stream()
                                    .filter(c -> c instanceof EqualTo)
                                    .count() == 2)
                    // hash conjuncts: 2 in total, both EqualTo over slot references only
                    .when(j -> j.getHashJoinConjuncts().size() == 2
                            && j.getHashJoinConjuncts().stream()
                                    .filter(c -> c instanceof EqualTo)
                                    .noneMatch(c -> !(((EqualTo) c).left() instanceof SlotReference)
                                            || !(((EqualTo) c).right() instanceof SlotReference))
                            && j.getHashJoinConjuncts().stream()
                                    .filter(c -> c instanceof EqualTo)
                                    .count() == 2));
}
}