[refactor](Nereids): polish code SemiJoinLogicalJoinTranspose. (#17740)
This commit is contained in:
@ -216,7 +216,7 @@ public enum RuleType {
|
||||
|
||||
// exploration rules
|
||||
TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_COMMUTE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_INNER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_OUTER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
|
||||
@ -225,7 +225,10 @@ public enum RuleType {
|
||||
LOGICAL_OUTER_JOIN_ASSOC_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_EXCHANGE_LEFT_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_EXCHANGE_RIGHT_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
@ -234,7 +237,6 @@ public enum RuleType {
|
||||
LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
|
||||
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ public class JoinCommute extends OneExplorationRuleFactory {
|
||||
}
|
||||
|
||||
return newJoin;
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
|
||||
}).toRule(RuleType.LOGICAL_JOIN_COMMUTE);
|
||||
}
|
||||
|
||||
enum SwapType {
|
||||
|
||||
@ -58,7 +58,6 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
|
||||
&& (topJoin.left().getJoinType().isInnerJoin()
|
||||
|| topJoin.left().getJoinType().isLeftOuterJoin()
|
||||
|| topJoin.left().getJoinType().isRightOuterJoin())))
|
||||
.whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin())
|
||||
.when(this::conditionChecker)
|
||||
.whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().hasJoinHint())
|
||||
.whenNot(LogicalJoin::isMarkJoin)
|
||||
|
||||
@ -17,14 +17,10 @@
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.ExprId;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
@ -34,9 +30,6 @@ import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
@ -65,7 +58,6 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
&& (topJoin.left().child().getJoinType().isInnerJoin()
|
||||
|| topJoin.left().child().getJoinType().isLeftOuterJoin()
|
||||
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
|
||||
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
|
||||
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
|
||||
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
|
||||
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
|
||||
@ -76,9 +68,8 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
GroupPlan b = bottomJoin.right();
|
||||
GroupPlan c = topSemiJoin.right();
|
||||
|
||||
// push topSemiJoin down project, so we need replace conjuncts by project.
|
||||
Pair<List<Expression>, List<Expression>> conjuncts = replaceConjuncts(topSemiJoin, project);
|
||||
Set<ExprId> conjunctsIds = Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
|
||||
Set<ExprId> conjunctsIds = Stream.concat(topSemiJoin.getHashJoinConjuncts().stream(),
|
||||
topSemiJoin.getOtherJoinConjuncts().stream())
|
||||
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
|
||||
ContainsType containsType = containsChildren(conjunctsIds, a.getOutputExprIdSet(),
|
||||
b.getOutputExprIdSet());
|
||||
@ -99,8 +90,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
// RIGHT_OUTER_JOIN should be eliminated in rewrite phase
|
||||
Preconditions.checkState(bottomJoin.getJoinType() != JoinType.RIGHT_OUTER_JOIN);
|
||||
|
||||
Plan newBottomSemiJoin = topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
|
||||
a, c);
|
||||
Plan newBottomSemiJoin = topSemiJoin.withChildren(a, c);
|
||||
Plan newTopJoin = bottomJoin.withChildren(newBottomSemiJoin, b);
|
||||
return project.withChildren(newTopJoin);
|
||||
} else {
|
||||
@ -112,37 +102,20 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
* / \ |
|
||||
* project C newTopJoin
|
||||
* | / \
|
||||
* bottomJoin C --> A newBottomSemiJoin
|
||||
* / \ / \
|
||||
* A B B C
|
||||
* bottomJoin C --> A newBottomSemiJoin
|
||||
* / \ / \
|
||||
* A B B C
|
||||
*/
|
||||
// LEFT_OUTER_JOIN should be eliminated in rewrite phase
|
||||
Preconditions.checkState(bottomJoin.getJoinType() != JoinType.LEFT_OUTER_JOIN);
|
||||
|
||||
Plan newBottomSemiJoin = topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
|
||||
b, c);
|
||||
Plan newBottomSemiJoin = topSemiJoin.withChildren(b, c);
|
||||
Plan newTopJoin = bottomJoin.withChildren(a, newBottomSemiJoin);
|
||||
return project.withChildren(newTopJoin);
|
||||
}
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
|
||||
}
|
||||
|
||||
private Pair<List<Expression>, List<Expression>> replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
|
||||
LogicalProject<? extends Plan> project) {
|
||||
Map<ExprId, Slot> outputToInput = new HashMap<>();
|
||||
for (NamedExpression outputExpr : project.getProjects()) {
|
||||
Set<Slot> usedSlots = outputExpr.getInputSlots();
|
||||
Preconditions.checkState(usedSlots.size() == 1);
|
||||
Slot inputSlot = usedSlots.iterator().next();
|
||||
outputToInput.put(outputExpr.getExprId(), inputSlot);
|
||||
}
|
||||
List<Expression> topHashConjuncts =
|
||||
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), outputToInput);
|
||||
List<Expression> topOtherConjuncts =
|
||||
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), outputToInput);
|
||||
return Pair.of(topHashConjuncts, topOtherConjuncts);
|
||||
}
|
||||
|
||||
enum ContainsType {
|
||||
LEFT, RIGHT, ALL
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
|
||||
Plan newBottomJoin = topJoin.withChildren(a, c);
|
||||
Plan newTopJoin = bottomJoin.withChildren(newBottomJoin, b);
|
||||
return newTopJoin;
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE);
|
||||
}
|
||||
|
||||
private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
|
||||
|
||||
@ -20,11 +20,10 @@ package org.apache.doris.nereids.rules.exploration.join;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
@ -33,13 +32,13 @@ import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SemiJoinLogicalJoinTransposeProjectTest {
|
||||
class SemiJoinLogicalJoinTransposeProjectTest implements MemoPatternMatchSupported {
|
||||
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeProjectLAsscom() {
|
||||
void testSemiJoinLogicalTransposeProjectLAsscom() {
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
@ -57,28 +56,21 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.applyExploration(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
|
||||
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0);
|
||||
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.left();
|
||||
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
|
||||
|
||||
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
|
||||
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
|
||||
LogicalOlapScan newTopJoinRight = (LogicalOlapScan) newTopJoin.right();
|
||||
|
||||
Assertions.assertEquals("t1", newBottomJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
|
||||
Assertions.assertEquals("t2", newTopJoinRight.getTable().getName());
|
||||
});
|
||||
.matchesExploration(
|
||||
logicalProject(
|
||||
innerLogicalJoin(
|
||||
leftSemiLogicalJoin(
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
|
||||
),
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeProjectLAsscomFail() {
|
||||
void testSemiJoinLogicalTransposeProjectLAsscomFail() {
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.project(ImmutableList.of(0, 2)) // t1.id, t2.id
|
||||
@ -94,15 +86,15 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeProjectAll() {
|
||||
void testSemiJoinLogicalTransposeProjectAll() {
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
* project C newTopJoin
|
||||
* | / \
|
||||
* bottomJoin C --> A newBottomSemiJoin
|
||||
* / \ / \
|
||||
* A B B C
|
||||
* / \ / \
|
||||
* A B B C
|
||||
*/
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
@ -112,23 +104,16 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.applyExploration(SemiJoinLogicalJoinTransposeProject.ALL.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
|
||||
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0);
|
||||
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.right();
|
||||
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
|
||||
|
||||
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
|
||||
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
|
||||
LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) newTopJoin.left();
|
||||
|
||||
Assertions.assertEquals("t1", newTopJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t2", newBottomJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
|
||||
});
|
||||
.matchesExploration(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
|
||||
leftSemiLogicalJoin(
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")),
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user