[refactor](Nereids): polish code SemiJoinLogicalJoinTranspose. (#17740)

This commit is contained in:
jakevin
2023-03-14 12:48:58 +08:00
committed by GitHub
parent 77ab2fac20
commit be3a7e69cd
6 changed files with 43 additions and 84 deletions

View File

@ -216,7 +216,7 @@ public enum RuleType {
// exploration rules
TEST_EXPLORATION(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_COMMUTE(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_OUTER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION),
@ -225,7 +225,10 @@ public enum RuleType {
LOGICAL_OUTER_JOIN_ASSOC_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_LOGICAL_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE_LEFT_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE_RIGHT_PROJECT(RuleTypeClass.EXPLORATION),
@ -234,7 +237,6 @@ public enum RuleType {
LOGICAL_INNER_JOIN_LEFT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_INNER_JOIN_RIGHT_ASSOCIATIVE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN(RuleTypeClass.EXPLORATION),
PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN(RuleTypeClass.EXPLORATION),

View File

@ -67,7 +67,7 @@ public class JoinCommute extends OneExplorationRuleFactory {
}
return newJoin;
}).toRule(RuleType.LOGICAL_JOIN_COMMUTATE);
}).toRule(RuleType.LOGICAL_JOIN_COMMUTE);
}
enum SwapType {

View File

@ -58,7 +58,6 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
&& (topJoin.left().getJoinType().isInnerJoin()
|| topJoin.left().getJoinType().isLeftOuterJoin()
|| topJoin.left().getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.whenNot(topJoin -> topJoin.hasJoinHint() || topJoin.left().hasJoinHint())
.whenNot(LogicalJoin::isMarkJoin)

View File

@ -17,14 +17,10 @@
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -34,9 +30,6 @@ import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@ -65,7 +58,6 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
&& (topJoin.left().child().getJoinType().isInnerJoin()
|| topJoin.left().child().getJoinType().isLeftOuterJoin()
|| topJoin.left().child().getJoinType().isRightOuterJoin())))
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(join -> JoinReorderUtils.isAllSlotProject(join.left()))
@ -76,9 +68,8 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
// push topSemiJoin down project, so we need replace conjuncts by project.
Pair<List<Expression>, List<Expression>> conjuncts = replaceConjuncts(topSemiJoin, project);
Set<ExprId> conjunctsIds = Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
Set<ExprId> conjunctsIds = Stream.concat(topSemiJoin.getHashJoinConjuncts().stream(),
topSemiJoin.getOtherJoinConjuncts().stream())
.flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
ContainsType containsType = containsChildren(conjunctsIds, a.getOutputExprIdSet(),
b.getOutputExprIdSet());
@ -99,8 +90,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
// RIGHT_OUTER_JOIN should be eliminated in rewrite phase
Preconditions.checkState(bottomJoin.getJoinType() != JoinType.RIGHT_OUTER_JOIN);
Plan newBottomSemiJoin = topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
a, c);
Plan newBottomSemiJoin = topSemiJoin.withChildren(a, c);
Plan newTopJoin = bottomJoin.withChildren(newBottomSemiJoin, b);
return project.withChildren(newTopJoin);
} else {
@ -112,37 +102,20 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B B C
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B B C
*/
// LEFT_OUTER_JOIN should be eliminated in rewrite phase
Preconditions.checkState(bottomJoin.getJoinType() != JoinType.LEFT_OUTER_JOIN);
Plan newBottomSemiJoin = topSemiJoin.withConjunctsChildren(conjuncts.first, conjuncts.second,
b, c);
Plan newBottomSemiJoin = topSemiJoin.withChildren(b, c);
Plan newTopJoin = bottomJoin.withChildren(a, newBottomSemiJoin);
return project.withChildren(newTopJoin);
}
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
}
private Pair<List<Expression>, List<Expression>> replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
LogicalProject<? extends Plan> project) {
Map<ExprId, Slot> outputToInput = new HashMap<>();
for (NamedExpression outputExpr : project.getProjects()) {
Set<Slot> usedSlots = outputExpr.getInputSlots();
Preconditions.checkState(usedSlots.size() == 1);
Slot inputSlot = usedSlots.iterator().next();
outputToInput.put(outputExpr.getExprId(), inputSlot);
}
List<Expression> topHashConjuncts =
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), outputToInput);
List<Expression> topOtherConjuncts =
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), outputToInput);
return Pair.of(topHashConjuncts, topOtherConjuncts);
}
enum ContainsType {
LEFT, RIGHT, ALL
}

View File

@ -73,7 +73,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
Plan newBottomJoin = topJoin.withChildren(a, c);
Plan newTopJoin = bottomJoin.withChildren(newBottomJoin, b);
return newTopJoin;
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANSPOSE);
}
private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {

View File

@ -20,11 +20,10 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
@ -33,13 +32,13 @@ import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class SemiJoinLogicalJoinTransposeProjectTest {
class SemiJoinLogicalJoinTransposeProjectTest implements MemoPatternMatchSupported {
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@Test
public void testSemiJoinLogicalTransposeProjectLAsscom() {
void testSemiJoinLogicalTransposeProjectLAsscom() {
/*-
* topSemiJoin project
* / \ |
@ -57,28 +56,21 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.applyExploration(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0);
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.left();
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
LogicalOlapScan newTopJoinRight = (LogicalOlapScan) newTopJoin.right();
Assertions.assertEquals("t1", newBottomJoinLeft.getTable().getName());
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
Assertions.assertEquals("t2", newTopJoinRight.getTable().getName());
});
.matchesExploration(
logicalProject(
innerLogicalJoin(
leftSemiLogicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
)
)
);
}
@Test
public void testSemiJoinLogicalTransposeProjectLAsscomFail() {
void testSemiJoinLogicalTransposeProjectLAsscomFail() {
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
.project(ImmutableList.of(0, 2)) // t1.id, t2.id
@ -94,15 +86,15 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
}
@Test
public void testSemiJoinLogicalTransposeProjectAll() {
void testSemiJoinLogicalTransposeProjectAll() {
/*-
* topSemiJoin project
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B B C
* / \ / \
* A B B C
*/
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
@ -112,23 +104,16 @@ public class SemiJoinLogicalJoinTransposeProjectTest {
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.applyExploration(SemiJoinLogicalJoinTransposeProject.ALL.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0);
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.right();
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) newTopJoin.left();
Assertions.assertEquals("t1", newTopJoinLeft.getTable().getName());
Assertions.assertEquals("t2", newBottomJoinLeft.getTable().getName());
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
});
.matchesExploration(
logicalProject(
logicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
leftSemiLogicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2")),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
)
)
)
);
}
}