[feature](Nereids): semi join transpose (#12590)

* [feature](Nereids): semi join transpose and enable ZIG_ZAG join reorder.
This commit is contained in:
jakevin
2022-09-15 21:32:50 +08:00
committed by GitHub
parent c6c84a2784
commit db8bc80c36
11 changed files with 394 additions and 71 deletions

View File

@ -21,6 +21,9 @@ import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
@ -55,6 +58,9 @@ public class RuleSet {
.add(JoinCommuteProject.LEFT_DEEP)
.add(JoinLAsscom.INNER)
.add(JoinLAsscomProject.INNER)
.add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
.add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
.add(SemiJoinSemiJoinTranspose.INSTANCE)
.add(new PushdownFilterThroughProject())
.add(new MergeConsecutiveProjects())
.build();
@ -140,6 +146,11 @@ public class RuleSet {
return this;
}
/**
 * Adds every rule in {@code rules} to this factory's rule collection.
 *
 * @param rules the rules to append
 * @return this instance, so calls can be chained
 */
public RuleFactories addAll(List<Rule> rules) {
this.rules.addAll(rules);
return this;
}
/**
 * Finishes construction by delegating to {@code rules.build()}.
 *
 * @return the list produced by the underlying {@code rules} builder
 *         (NOTE(review): presumably all rules added so far; the builder's
 *         declaration is outside this view - confirm)
 */
public List<Rule> build() {
return rules.build();
}

View File

@ -109,8 +109,7 @@ public enum RuleType {
OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
// Pushdown filter
PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE),
LOGICAL_LIMIT_TO_LOGICAL_EMPTY_RELATION_RULE(RuleTypeClass.REWRITE),
SWAP_LIMIT_PROJECT(RuleTypeClass.REWRITE),
PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE),
REWRITE_SENTINEL(RuleTypeClass.REWRITE),
// limit push down
@ -122,7 +121,11 @@ public enum RuleType {
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_L_ASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),
// implementation rules
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),

View File

@ -75,6 +75,6 @@ public class JoinLAsscomProject extends OneExplorationRuleFactory {
return null;
}
return helper.newTopJoin();
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM_PROJECT);
}
}

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@ -27,6 +28,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import java.util.List;
import java.util.Set;
/**
@ -42,9 +44,21 @@ import java.util.Set;
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
public static final SemiJoinLogicalJoinTranspose LEFT_DEEP = new SemiJoinLogicalJoinTranspose(true);
public static final SemiJoinLogicalJoinTranspose ALL = new SemiJoinLogicalJoinTranspose(false);
private final boolean leftDeep;
public SemiJoinLogicalJoinTranspose(boolean leftDeep) {
this.leftDeep = leftDeep;
}
@Override
public Rule build() {
return leftSemiLogicalJoin(logicalJoin(), group())
.whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topSemiJoin.left();
@ -52,7 +66,14 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
boolean lasscom = bottomJoin.getOutputSet().containsAll(a.getOutput());
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
Set<Slot> aOutputSet = a.getOutputSet();
boolean lasscom = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
}
if (lasscom) {
/*
@ -81,20 +102,27 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
}
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE);
}
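// Returns true only when the semi join's hash conjuncts reference slots from exactly one of A or B
// (in left-deep mode a reference to B is rejected); fails the state check if they reference neither.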
private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
Set<Slot> bottomOutputSet = topJoin.left().getOutputSet();
private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topSemiJoin) {
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
Set<Slot> aOutputSet = topJoin.left().left().getOutputSet();
Set<Slot> bOutputSet = topJoin.left().right().getOutputSet();
List<Slot> aOutput = topSemiJoin.left().left().getOutput();
List<Slot> bOutput = topSemiJoin.left().right().getOutput();
boolean isProjectA = !ExpressionUtils.isIntersecting(bottomOutputSet, aOutputSet);
boolean isProjectB = !ExpressionUtils.isIntersecting(bottomOutputSet, bOutputSet);
Preconditions.checkState(isProjectA || isProjectB, "join output must contain child");
return !(isProjectA && isProjectB);
boolean hashContainsA = false;
boolean hashContainsB = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA;
hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB;
}
if (leftDeep && hashContainsB) {
return false;
}
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
return !(hashContainsA && hashContainsB);
}
}

View File

@ -20,15 +20,17 @@ package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
@ -45,9 +47,20 @@ import java.util.Set;
* which operands actually participate in the semi-join.
*/
public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFactory {
public static final SemiJoinLogicalJoinTransposeProject LEFT_DEEP = new SemiJoinLogicalJoinTransposeProject(true);
public static final SemiJoinLogicalJoinTransposeProject ALL = new SemiJoinLogicalJoinTransposeProject(false);
private final boolean leftDeep;
public SemiJoinLogicalJoinTransposeProject(boolean leftDeep) {
this.leftDeep = leftDeep;
}
@Override
public Rule build() {
return leftSemiLogicalJoin(logicalProject(logicalJoin()), group())
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
.when(this::conditionChecker)
.then(topSemiJoin -> {
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
@ -56,67 +69,77 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
GroupPlan b = bottomJoin.right();
GroupPlan c = topSemiJoin.right();
boolean lasscom = a.getOutputSet().containsAll(project.getOutput());
Set<Slot> aOutputSet = a.getOutputSet();
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
boolean lasscom = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
}
if (lasscom) {
/*-
* topSemiJoin newTopProject
* / \ |
* topSemiJoin project
* / \ |
* project C newTopJoin
* | -> / \
* bottomJoin newBottomSemiJoin B
* bottomJoin newBottomSemiJoin B
* / \ / \
* A B aNewProject C
* |
* A
* A B A C
*/
List<NamedExpression> projects = project.getProjects();
LogicalProject<GroupPlan> aNewProject = new LogicalProject<>(projects, a);
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinCondition(), aNewProject, c);
LogicalJoin<LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>, GroupPlan> newTopJoin
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b);
return new LogicalProject<>(projects, newTopJoin);
topSemiJoin.getOtherJoinCondition(), a, c);
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(),
newBottomSemiJoin, b);
return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
} else {
/*-
* topSemiJoin newTopProject
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B bNewProject C
* |
* B
* topSemiJoin project
* / \ |
* project C newTopJoin
* | / \
* bottomJoin C --> A newBottomSemiJoin
* / \ / \
* A B B C
*/
List<NamedExpression> projects = project.getProjects();
LogicalProject<GroupPlan> bNewProject = new LogicalProject<>(projects, b);
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
topSemiJoin.getOtherJoinCondition(), bNewProject, c);
topSemiJoin.getOtherJoinCondition(), b, c);
LogicalJoin<GroupPlan, LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>> newTopJoin
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
return new LogicalProject<>(projects, newTopJoin);
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(),
a, newBottomSemiJoin);
return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
}
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
}
// bottomJoin just return A OR B, else return false.
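// Returns true only when the semi join's hash conjuncts reference slots from exactly one of A or B
// (in left-deep mode a reference to B is rejected); fails the state check if they reference neither.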
private boolean conditionChecker(
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) {
Set<Slot> projectOutputSet = topJoin.left().getOutputSet();
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemiJoin) {
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
Set<Slot> aOutputSet = topJoin.left().child().left().getOutputSet();
Set<Slot> bOutputSet = topJoin.left().child().right().getOutputSet();
List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();
boolean isProjectA = !ExpressionUtils.isIntersecting(projectOutputSet, aOutputSet);
boolean isProjectB = !ExpressionUtils.isIntersecting(projectOutputSet, bOutputSet);
Preconditions.checkState(isProjectA || isProjectB, "project must contain child");
return !(isProjectA && isProjectB);
boolean hashContainsA = false;
boolean hashContainsB = false;
for (Expression hashJoinConjunct : hashJoinConjuncts) {
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA;
hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB;
}
if (leftDeep && hashContainsB) {
return false;
}
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
return !(hashContainsA && hashContainsB);
}
}

View File

@ -37,6 +37,7 @@ import java.util.Set;
* LEFT-Semi/ANTI(X, LEFT-Semi/ANTI(Y, Z))
*/
public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
public static final SemiJoinSemiJoinTranspose INSTANCE = new SemiJoinSemiJoinTranspose();
public static Set<Pair<JoinType, JoinType>> typeSet = ImmutableSet.of(
Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_SEMI_JOIN),
@ -69,7 +70,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
newBottomJoin, b);
return newTopJoin;
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
}
private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {

View File

@ -54,6 +54,6 @@ public class PushdownProjectThroughLimit extends OneRewriteRuleFactory {
return new LogicalLimit<LogicalProject<GroupPlan>>(logicalLimit.getLimit(),
logicalLimit.getOffset(), new LogicalProject<>(logicalProject.getProjects(),
logicalLimit.child()));
}).toRule(RuleType.SWAP_LIMIT_PROJECT);
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGHT_LIMIT);
}
}

View File

@ -0,0 +1,134 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
 * Exercises {@link SemiJoinLogicalJoinTransposeProject} on plans shaped like
 * {@code LEFT_SEMI(project(INNER(t1, t2)), t3)} and inspects the alternative
 * the rule registers in the root group of the memo.
 */
public class SemiJoinLogicalJoinTransposeProjectTest {
    private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
    private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
    private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);

    /**
     * Builds the plan shared by the "fail" and "all" cases: the semi join's
     * hash condition is taken from the right input (t2) of the bottom join.
     */
    private static LogicalPlan semiJoinKeyedOnRightInput() {
        return new LogicalPlanBuilder(scan1)
                .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
                .project(ImmutableList.of(0, 2)) // t1.id, t2.id
                .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 0)) // t2.id = t3.id
                .build();
    }

    /** Semi join keyed on t1: the left-deep rule pushes it below the inner join, onto A. */
    @Test
    public void testSemiJoinLogicalTransposeProjectLAsscom() {
        /*-
         *     topSemiJoin                    project
         *      /       \                        |
         *   project     C                  newTopJoin
         *      |             ->            /        \
         *  bottomJoin           newBottomSemiJoin    B
         *   /     \                 /      \
         *  A       B               A        C
         */
        LogicalPlan semiOverProject = new LogicalPlanBuilder(scan1)
                .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
                .project(ImmutableList.of(0))
                .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t3.id
                .build();

        PlanChecker.from(MemoTestUtils.createConnectContext(), semiOverProject)
                .transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
                .checkMemo(memo -> {
                    Group rootGroup = memo.getRoot();
                    // original expression plus the one produced by the rule
                    Assertions.assertEquals(2, rootGroup.getLogicalExpressions().size());

                    Plan rewritten = memo.copyOut(rootGroup.getLogicalExpressions().get(1), false);
                    // the rewritten root is a project; the joins sit underneath it
                    LogicalJoin<?, ?> upperJoin = (LogicalJoin<?, ?>) rewritten.child(0);
                    LogicalJoin<?, ?> lowerJoin = (LogicalJoin<?, ?>) upperJoin.left();
                    Assertions.assertEquals(JoinType.INNER_JOIN, upperJoin.getJoinType());
                    Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, lowerJoin.getJoinType());

                    LogicalOlapScan lowerLeftScan = (LogicalOlapScan) lowerJoin.left();
                    LogicalOlapScan lowerRightScan = (LogicalOlapScan) lowerJoin.right();
                    LogicalOlapScan upperRightScan = (LogicalOlapScan) upperJoin.right();
                    Assertions.assertEquals("t1", lowerLeftScan.getTable().getName());
                    Assertions.assertEquals("t3", lowerRightScan.getTable().getName());
                    Assertions.assertEquals("t2", upperRightScan.getTable().getName());
                });
    }

    /** Semi join keyed on t2: the left-deep rule must not add any alternative. */
    @Test
    public void testSemiJoinLogicalTransposeProjectLAsscomFail() {
        LogicalPlan semiOverProject = semiJoinKeyedOnRightInput();

        PlanChecker.from(MemoTestUtils.createConnectContext(), semiOverProject)
                .transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
                .checkMemo(memo -> {
                    Group rootGroup = memo.getRoot();
                    Assertions.assertEquals(1, rootGroup.getLogicalExpressions().size());
                });
    }

    /** Semi join keyed on t2: the unrestricted rule pushes it onto B. */
    @Test
    public void testSemiJoinLogicalTransposeProjectAll() {
        /*-
         *     topSemiJoin                  project
         *       /     \                       |
         *    project   C                  newTopJoin
         *       |                        /         \
         *  bottomJoin  C     -->       A     newBottomSemiJoin
         *   /    \                               /      \
         *  A      B                             B        C
         */
        LogicalPlan semiOverProject = semiJoinKeyedOnRightInput();

        PlanChecker.from(MemoTestUtils.createConnectContext(), semiOverProject)
                .transform(SemiJoinLogicalJoinTransposeProject.ALL.build())
                .checkMemo(memo -> {
                    Group rootGroup = memo.getRoot();
                    Assertions.assertEquals(2, rootGroup.getLogicalExpressions().size());

                    Plan rewritten = memo.copyOut(rootGroup.getLogicalExpressions().get(1), false);
                    LogicalJoin<?, ?> upperJoin = (LogicalJoin<?, ?>) rewritten.child(0);
                    LogicalJoin<?, ?> lowerJoin = (LogicalJoin<?, ?>) upperJoin.right();
                    Assertions.assertEquals(JoinType.INNER_JOIN, upperJoin.getJoinType());
                    Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, lowerJoin.getJoinType());

                    LogicalOlapScan lowerLeftScan = (LogicalOlapScan) lowerJoin.left();
                    LogicalOlapScan lowerRightScan = (LogicalOlapScan) lowerJoin.right();
                    LogicalOlapScan upperLeftScan = (LogicalOlapScan) upperJoin.left();
                    Assertions.assertEquals("t1", upperLeftScan.getTable().getName());
                    Assertions.assertEquals("t2", lowerLeftScan.getTable().getName());
                    Assertions.assertEquals("t3", lowerRightScan.getTable().getName());
                });
    }
}

View File

@ -0,0 +1,126 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
/**
 * Exercises {@link SemiJoinLogicalJoinTranspose} on plans shaped like
 * {@code LEFT_SEMI(INNER(t1, t2), t3)} and inspects the alternative the rule
 * registers in the root group of the memo.
 */
public class SemiJoinLogicalJoinTransposeTest {
    private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
    private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
    private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);

    /**
     * Builds the plan shared by the "fail" and "all" cases: the semi join's
     * hash condition is taken from the right input (t2) of the bottom join.
     */
    private static LogicalPlan semiJoinKeyedOnRightInput() {
        return new LogicalPlanBuilder(scan1)
                .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
                .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(2, 0)) // t2.id = t3.id
                .build();
    }

    /** Semi join keyed on t1: the left-deep rule pushes it below the inner join, onto A. */
    @Test
    public void testSemiJoinLogicalTransposeLAsscom() {
        /*
         *    topSemiJoin                newTopJoin
         *      /     \                 /         \
         * bottomJoin  C   -->  newBottomSemiJoin  B
         *  /    \                  /    \
         * A      B                A      C
         */
        LogicalPlan semiOverJoin = new LogicalPlanBuilder(scan1)
                .hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
                .hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t3.id
                .build();

        PlanChecker.from(MemoTestUtils.createConnectContext(), semiOverJoin)
                .transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build())
                .checkMemo(memo -> {
                    Group rootGroup = memo.getRoot();
                    // original expression plus the one produced by the rule
                    Assertions.assertEquals(2, rootGroup.getLogicalExpressions().size());

                    Plan rewritten = memo.copyOut(rootGroup.getLogicalExpressions().get(1), false);
                    LogicalJoin<?, ?> upperJoin = (LogicalJoin<?, ?>) rewritten;
                    LogicalJoin<?, ?> lowerJoin = (LogicalJoin<?, ?>) upperJoin.left();
                    Assertions.assertEquals(JoinType.INNER_JOIN, upperJoin.getJoinType());
                    Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, lowerJoin.getJoinType());

                    LogicalOlapScan lowerLeftScan = (LogicalOlapScan) lowerJoin.left();
                    LogicalOlapScan lowerRightScan = (LogicalOlapScan) lowerJoin.right();
                    LogicalOlapScan upperRightScan = (LogicalOlapScan) upperJoin.right();
                    Assertions.assertEquals("t1", lowerLeftScan.getTable().getName());
                    Assertions.assertEquals("t3", lowerRightScan.getTable().getName());
                    Assertions.assertEquals("t2", upperRightScan.getTable().getName());
                });
    }

    /** Semi join keyed on t2: the left-deep rule must not add any alternative. */
    @Test
    public void testSemiJoinLogicalTransposeLAsscomFail() {
        LogicalPlan semiOverJoin = semiJoinKeyedOnRightInput();

        PlanChecker.from(MemoTestUtils.createConnectContext(), semiOverJoin)
                .transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build())
                .checkMemo(memo -> {
                    Group rootGroup = memo.getRoot();
                    Assertions.assertEquals(1, rootGroup.getLogicalExpressions().size());
                });
    }

    /** Semi join keyed on t2: the unrestricted rule pushes it onto B. */
    @Test
    public void testSemiJoinLogicalTransposeAll() {
        /*
         *    topSemiJoin            newTopJoin
         *      /     \             /         \
         * bottomJoin  C   -->     A   newBottomSemiJoin
         *  /    \                         /      \
         * A      B                       B        C
         */
        LogicalPlan semiOverJoin = semiJoinKeyedOnRightInput();

        PlanChecker.from(MemoTestUtils.createConnectContext(), semiOverJoin)
                .transform(SemiJoinLogicalJoinTranspose.ALL.build())
                .checkMemo(memo -> {
                    Group rootGroup = memo.getRoot();
                    Assertions.assertEquals(2, rootGroup.getLogicalExpressions().size());

                    Plan rewritten = memo.copyOut(rootGroup.getLogicalExpressions().get(1), false);
                    LogicalJoin<?, ?> upperJoin = (LogicalJoin<?, ?>) rewritten;
                    LogicalJoin<?, ?> lowerJoin = (LogicalJoin<?, ?>) upperJoin.right();
                    Assertions.assertEquals(JoinType.INNER_JOIN, upperJoin.getJoinType());
                    Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, lowerJoin.getJoinType());

                    LogicalOlapScan upperLeftScan = (LogicalOlapScan) upperJoin.left();
                    LogicalOlapScan lowerLeftScan = (LogicalOlapScan) lowerJoin.left();
                    LogicalOlapScan lowerRightScan = (LogicalOlapScan) lowerJoin.right();
                    Assertions.assertEquals("t1", upperLeftScan.getTable().getName());
                    Assertions.assertEquals("t2", lowerLeftScan.getTable().getName());
                    Assertions.assertEquals("t3", lowerRightScan.getTable().getName());
                });
    }
}

View File

@ -29,11 +29,10 @@ import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
public class SemiJoinTransposeTest {
public class SemiJoinSemiJoinTransposeTest {
public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@ -41,21 +40,19 @@ public class SemiJoinTransposeTest {
@Test
public void testSemiJoinLogicalTransposeCommute() {
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
.project(ImmutableList.of(0))
.hashJoinUsing(scan2, JoinType.LEFT_ANTI_JOIN, Pair.of(0, 0))
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.transform((new SemiJoinLogicalJoinTransposeProject()).build())
.transform(SemiJoinSemiJoinTranspose.INSTANCE.build())
.checkMemo(memo -> {
Group root = memo.getRoot();
Assertions.assertEquals(2, root.getLogicalExpressions().size());
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
Plan join = memo.copyOut(root.getLogicalExpressions().get(1), false);
Plan join = plan.child(0);
Assertions.assertTrue(join instanceof LogicalJoin);
Assertions.assertEquals(JoinType.INNER_JOIN, ((LogicalJoin<?, ?>) join).getJoinType());
Assertions.assertEquals(JoinType.LEFT_ANTI_JOIN, ((LogicalJoin<?, ?>) join).getJoinType());
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
((LogicalJoin<?, ?>) ((LogicalJoin<?, ?>) join).left()).getJoinType());
});

View File

@ -58,10 +58,10 @@ public class LogicalPlanBuilder {
return from(project);
}
public LogicalPlanBuilder project(List<Integer> slots) {
public LogicalPlanBuilder project(List<Integer> slotsIndex) {
List<NamedExpression> projectExprs = Lists.newArrayList();
for (int i = 0; i < slots.size(); i++) {
projectExprs.add(this.plan.getOutput().get(i));
for (Integer index : slotsIndex) {
projectExprs.add(this.plan.getOutput().get(index));
}
LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan);
return from(project);