Cherry-picked from #50886 Co-authored-by: morrySnow <zhangwenxin@selectdb.com>
This commit is contained in:
committed by
GitHub
parent
01f70deb8b
commit
82d1375dc5
@ -17,6 +17,7 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
@ -62,8 +63,19 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalJoin()
|
||||
.when(join -> join.getHashJoinConjuncts().stream().anyMatch(equalTo ->
|
||||
equalTo.children().stream().anyMatch(e -> !(e instanceof Slot))))
|
||||
.when(join -> {
|
||||
boolean needProcessHashConjuncts = join.getHashJoinConjuncts().stream()
|
||||
.anyMatch(equalTo -> equalTo.children().stream()
|
||||
.anyMatch(e -> !(e instanceof Slot)));
|
||||
List<Slot> leftSlots = join.left().getOutput();
|
||||
List<Slot> rightSlots = join.right().getOutput();
|
||||
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
|
||||
leftSlots, rightSlots, join.getMarkJoinConjuncts());
|
||||
boolean needProcessMarkConjuncts = pair.first.stream()
|
||||
.anyMatch(equalTo -> equalTo.children().stream()
|
||||
.anyMatch(e -> !(e instanceof Slot)));
|
||||
return needProcessHashConjuncts || needProcessMarkConjuncts;
|
||||
})
|
||||
.then(PushDownExpressionsInHashCondition::pushDownHashExpression)
|
||||
.toRule(RuleType.PUSH_DOWN_EXPRESSIONS_IN_HASH_CONDITIONS);
|
||||
}
|
||||
@ -75,15 +87,20 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
|
||||
LogicalJoin<? extends Plan, ? extends Plan> join) {
|
||||
Set<NamedExpression> leftProjectExprs = Sets.newHashSet();
|
||||
Set<NamedExpression> rightProjectExprs = Sets.newHashSet();
|
||||
Map<Expression, NamedExpression> exprReplaceMap = Maps.newHashMap();
|
||||
Map<Expression, NamedExpression> replaceMap = Maps.newHashMap();
|
||||
join.getHashJoinConjuncts().forEach(conjunct -> {
|
||||
Preconditions.checkArgument(conjunct instanceof EqualPredicate);
|
||||
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
|
||||
// doesn't swap the two sides.
|
||||
conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet());
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs);
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs);
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, leftProjectExprs);
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, rightProjectExprs);
|
||||
});
|
||||
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
|
||||
.map(equalTo -> equalTo.withChildren(equalTo.children()
|
||||
.stream().map(expr -> replaceMap.get(expr).toSlot())
|
||||
.collect(ImmutableList.toImmutableList())))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
|
||||
// add other conjuncts used slots to project exprs
|
||||
Set<ExprId> leftExprIdSet = join.left().getOutputExprIdSet();
|
||||
@ -100,7 +117,28 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
|
||||
});
|
||||
|
||||
// add mark conjuncts used slots to project exprs
|
||||
join.getMarkJoinConjuncts().stream().flatMap(conjunct ->
|
||||
// if mark conjuncts could be hash condition, normalize it
|
||||
List<Slot> leftSlots = join.left().getOutput();
|
||||
List<Slot> rightSlots = join.right().getOutput();
|
||||
Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(leftSlots,
|
||||
rightSlots, join.getMarkJoinConjuncts());
|
||||
pair.first.forEach(conjunct -> {
|
||||
Preconditions.checkArgument(conjunct instanceof EqualPredicate);
|
||||
// sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it
|
||||
// doesn't swap the two sides.
|
||||
conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet());
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(0), replaceMap, leftProjectExprs);
|
||||
generateReplaceMapAndProjectExprs(conjunct.child(1), replaceMap, rightProjectExprs);
|
||||
});
|
||||
ImmutableList.Builder<Expression> newMarkConjunctsBuilder = ImmutableList.builder();
|
||||
pair.first.stream()
|
||||
.map(equalTo -> equalTo.withChildren(equalTo.children()
|
||||
.stream().map(expr -> replaceMap.get(expr).toSlot())
|
||||
.collect(ImmutableList.toImmutableList())))
|
||||
.forEach(newMarkConjunctsBuilder::add);
|
||||
newMarkConjunctsBuilder.addAll(pair.second);
|
||||
|
||||
pair.second.stream().flatMap(conjunct ->
|
||||
conjunct.getInputSlots().stream()
|
||||
).forEach(slot -> {
|
||||
if (leftExprIdSet.contains(slot.getExprId())) {
|
||||
@ -111,14 +149,9 @@ public class PushDownExpressionsInHashCondition extends OneRewriteRuleFactory {
|
||||
rightProjectExprs.add(slot);
|
||||
}
|
||||
});
|
||||
|
||||
List<Expression> newHashConjuncts = join.getHashJoinConjuncts().stream()
|
||||
.map(equalTo -> equalTo.withChildren(equalTo.children()
|
||||
.stream().map(expr -> exprReplaceMap.get(expr).toSlot())
|
||||
.collect(ImmutableList.toImmutableList())))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return join.withHashJoinConjunctsAndChildren(
|
||||
return join.withHashAndMarkJoinConjunctsAndChildren(
|
||||
newHashConjuncts,
|
||||
newMarkConjunctsBuilder.build(),
|
||||
createChildProjectPlan(join.left(), join, leftProjectExprs),
|
||||
createChildProjectPlan(join.right(), join, rightProjectExprs), join.getJoinReorderContext());
|
||||
|
||||
|
||||
@ -385,8 +385,21 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
|
||||
children, otherJoinReorderContext);
|
||||
}
|
||||
|
||||
public LogicalJoin<Plan, Plan> withHashJoinConjunctsAndChildren(
|
||||
List<Expression> hashJoinConjuncts, Plan left, Plan right, JoinReorderContext otherJoinReorderContext) {
|
||||
/**
|
||||
* Creates a new LogicalJoin with updated hash join conjuncts, mark join conjuncts, and child plans.
|
||||
*
|
||||
* @param hashJoinConjuncts the list of hash join conjuncts used for hash-based join conditions.
|
||||
* @param markJoinConjuncts the list of mark join conjuncts used for marking specific join conditions.
|
||||
* These are typically used in semi-join or anti-join scenarios to track
|
||||
* whether a condition is satisfied.
|
||||
* @param left the left child plan.
|
||||
* @param right the right child plan.
|
||||
* @param otherJoinReorderContext the context for join reordering.
|
||||
* @return a new LogicalJoin instance with the specified parameters.
|
||||
*/
|
||||
public LogicalJoin<Plan, Plan> withHashAndMarkJoinConjunctsAndChildren(
|
||||
List<Expression> hashJoinConjuncts, List<Expression> markJoinConjuncts,
|
||||
Plan left, Plan right, JoinReorderContext otherJoinReorderContext) {
|
||||
Preconditions.checkArgument(children.size() == 2);
|
||||
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
|
||||
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),
|
||||
|
||||
@ -20,7 +20,26 @@ package org.apache.doris.nereids.rules.rewrite;
|
||||
import org.apache.doris.nereids.NereidsPlanner;
|
||||
import org.apache.doris.nereids.parser.NereidsParser;
|
||||
import org.apache.doris.nereids.properties.PhysicalProperties;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
|
||||
import org.apache.doris.nereids.trees.UnaryNode;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Positive;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.RelationId;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.utframe.TestWithFeService;
|
||||
@ -184,4 +203,68 @@ public class PushDownExpressionsInHashConditionTest extends TestWithFeService im
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testPushDownMarkConjuncts() {
|
||||
Plan left = new LogicalOneRowRelation(new RelationId(1),
|
||||
ImmutableList.of(new Alias(new ExprId(1), new IntegerLiteral(1), "a")));
|
||||
Plan right = new LogicalOneRowRelation(new RelationId(2),
|
||||
ImmutableList.of(new Alias(new ExprId(2), new IntegerLiteral(2), "b")));
|
||||
Expression sameLeft = new Abs(left.getOutput().get(0));
|
||||
Expression sameRight = new Positive(right.getOutput().get(0));
|
||||
Expression hashLeft = new Cast(sameLeft, StringType.INSTANCE);
|
||||
Expression hashRight = new Cast(sameRight, StringType.INSTANCE);
|
||||
Expression markLeft = new Cast(sameLeft, BigIntType.INSTANCE);
|
||||
Expression markRight = new Cast(sameRight, BigIntType.INSTANCE);
|
||||
|
||||
LogicalJoin<?, ?> plan = new LogicalJoin<>(
|
||||
JoinType.INNER_JOIN,
|
||||
left,
|
||||
right,
|
||||
new JoinReorderContext()
|
||||
);
|
||||
|
||||
Expression sameConjuncts = new EqualTo(sameLeft, sameRight);
|
||||
Expression hashConjuncts = new EqualTo(hashLeft, hashRight);
|
||||
Expression markConjuncts = new EqualTo(markLeft, markRight);
|
||||
Expression otherConjuncts = new Add(left.getOutput().get(0), new IntegerLiteral(1));
|
||||
|
||||
plan = plan.withJoinConjuncts(ImmutableList.of(sameConjuncts, hashConjuncts), ImmutableList.of(otherConjuncts),
|
||||
ImmutableList.of(sameConjuncts, markConjuncts, otherConjuncts),
|
||||
new JoinReorderContext());
|
||||
|
||||
PlanChecker.from(connectContext, plan).applyTopDown(new PushDownExpressionsInHashCondition())
|
||||
.matches(logicalJoin(logicalProject(logicalOneRowRelation())
|
||||
.when(p -> p.getProjects().size() == 4
|
||||
&& p.getProjects().stream().filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast).map(UnaryNode::child)
|
||||
.filter(sameLeft::equals).count() == 1
|
||||
&& p.getProjects().stream().filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast).map(UnaryNode::child)
|
||||
.filter(markLeft::equals).count() == 1
|
||||
&& p.getProjects().stream().filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast).map(UnaryNode::child)
|
||||
.filter(hashLeft::equals).count() == 1),
|
||||
logicalProject(logicalOneRowRelation())
|
||||
.when(p -> p.getProjects().size() == 4
|
||||
&& p.getProjects().stream().filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast).map(UnaryNode::child)
|
||||
.filter(sameRight::equals).count() == 1
|
||||
&& p.getProjects().stream().filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast).map(UnaryNode::child)
|
||||
.filter(markRight::equals).count() == 1
|
||||
&& p.getProjects().stream().filter(Alias.class::isInstance)
|
||||
.map(Alias.class::cast).map(UnaryNode::child)
|
||||
.filter(hashRight::equals).count() == 1)
|
||||
).when(j -> j.getMarkJoinConjuncts().size() == 3
|
||||
&& j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance)
|
||||
.allMatch(e -> ((EqualTo) e).left() instanceof SlotReference
|
||||
&& ((EqualTo) e).right() instanceof SlotReference)
|
||||
&& j.getMarkJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() == 2)
|
||||
.when(j -> j.getHashJoinConjuncts().size() == 2
|
||||
&& j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance)
|
||||
.allMatch(e -> ((EqualTo) e).left() instanceof SlotReference
|
||||
&& ((EqualTo) e).right() instanceof SlotReference)
|
||||
&& j.getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance).count() == 2));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user