[feature](Nereids): semi join transpose (#12590)
* [feature](Nereids): semi join transpose and enable ZIG_ZAG join reorder.
This commit is contained in:
@ -21,6 +21,9 @@ import org.apache.doris.nereids.rules.exploration.join.JoinCommute;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom;
|
||||
import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose;
|
||||
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
|
||||
@ -55,6 +58,9 @@ public class RuleSet {
|
||||
.add(JoinCommuteProject.LEFT_DEEP)
|
||||
.add(JoinLAsscom.INNER)
|
||||
.add(JoinLAsscomProject.INNER)
|
||||
.add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
|
||||
.add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
|
||||
.add(SemiJoinSemiJoinTranspose.INSTANCE)
|
||||
.add(new PushdownFilterThroughProject())
|
||||
.add(new MergeConsecutiveProjects())
|
||||
.build();
|
||||
@ -140,6 +146,11 @@ public class RuleSet {
|
||||
return this;
|
||||
}
|
||||
|
||||
public RuleFactories addAll(List<Rule> rules) {
|
||||
this.rules.addAll(rules);
|
||||
return this;
|
||||
}
|
||||
|
||||
public List<Rule> build() {
|
||||
return rules.build();
|
||||
}
|
||||
|
||||
@ -109,8 +109,7 @@ public enum RuleType {
|
||||
OLAP_SCAN_PARTITION_PRUNE(RuleTypeClass.REWRITE),
|
||||
// Pushdown filter
|
||||
PUSHDOWN_FILTER_THROUGH_PROJET(RuleTypeClass.REWRITE),
|
||||
LOGICAL_LIMIT_TO_LOGICAL_EMPTY_RELATION_RULE(RuleTypeClass.REWRITE),
|
||||
SWAP_LIMIT_PROJECT(RuleTypeClass.REWRITE),
|
||||
PUSHDOWN_PROJECT_THROUGHT_LIMIT(RuleTypeClass.REWRITE),
|
||||
REWRITE_SENTINEL(RuleTypeClass.REWRITE),
|
||||
|
||||
// limit push down
|
||||
@ -122,7 +121,11 @@ public enum RuleType {
|
||||
LOGICAL_JOIN_COMMUTATE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_LEFT_JOIN_ASSOCIATIVE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_L_ASSCOM(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_L_ASSCOM_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_JOIN_EXCHANGE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION),
|
||||
|
||||
// implementation rules
|
||||
LOGICAL_ONE_ROW_RELATION_TO_PHYSICAL_ONE_ROW_RELATION(RuleTypeClass.IMPLEMENTATION),
|
||||
|
||||
@ -75,6 +75,6 @@ public class JoinLAsscomProject extends OneExplorationRuleFactory {
|
||||
return null;
|
||||
}
|
||||
return helper.newTopJoin();
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM_PROJECT);
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
@ -27,6 +28,7 @@ import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
@ -42,9 +44,21 @@ import java.util.Set;
|
||||
* which operands actually participate in the semi-join.
|
||||
*/
|
||||
public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
|
||||
|
||||
public static final SemiJoinLogicalJoinTranspose LEFT_DEEP = new SemiJoinLogicalJoinTranspose(true);
|
||||
|
||||
public static final SemiJoinLogicalJoinTranspose ALL = new SemiJoinLogicalJoinTranspose(false);
|
||||
|
||||
private final boolean leftDeep;
|
||||
|
||||
public SemiJoinLogicalJoinTranspose(boolean leftDeep) {
|
||||
this.leftDeep = leftDeep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return leftSemiLogicalJoin(logicalJoin(), group())
|
||||
.whenNot(topJoin -> topJoin.left().getJoinType().isSemiOrAntiJoin())
|
||||
.when(this::conditionChecker)
|
||||
.then(topSemiJoin -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topSemiJoin.left();
|
||||
@ -52,7 +66,14 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
|
||||
GroupPlan b = bottomJoin.right();
|
||||
GroupPlan c = topSemiJoin.right();
|
||||
|
||||
boolean lasscom = bottomJoin.getOutputSet().containsAll(a.getOutput());
|
||||
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
|
||||
Set<Slot> aOutputSet = a.getOutputSet();
|
||||
|
||||
boolean lasscom = false;
|
||||
for (Expression hashJoinConjunct : hashJoinConjuncts) {
|
||||
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
|
||||
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
|
||||
}
|
||||
|
||||
if (lasscom) {
|
||||
/*
|
||||
@ -81,20 +102,27 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory {
|
||||
return new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
|
||||
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
|
||||
}
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE);
|
||||
}
|
||||
|
||||
// bottomJoin just return A OR B, else return false.
|
||||
private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
|
||||
Set<Slot> bottomOutputSet = topJoin.left().getOutputSet();
|
||||
private boolean conditionChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topSemiJoin) {
|
||||
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
|
||||
|
||||
Set<Slot> aOutputSet = topJoin.left().left().getOutputSet();
|
||||
Set<Slot> bOutputSet = topJoin.left().right().getOutputSet();
|
||||
List<Slot> aOutput = topSemiJoin.left().left().getOutput();
|
||||
List<Slot> bOutput = topSemiJoin.left().right().getOutput();
|
||||
|
||||
boolean isProjectA = !ExpressionUtils.isIntersecting(bottomOutputSet, aOutputSet);
|
||||
boolean isProjectB = !ExpressionUtils.isIntersecting(bottomOutputSet, bOutputSet);
|
||||
|
||||
Preconditions.checkState(isProjectA || isProjectB, "join output must contain child");
|
||||
return !(isProjectA && isProjectB);
|
||||
boolean hashContainsA = false;
|
||||
boolean hashContainsB = false;
|
||||
for (Expression hashJoinConjunct : hashJoinConjuncts) {
|
||||
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
|
||||
hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA;
|
||||
hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB;
|
||||
}
|
||||
if (leftDeep && hashContainsB) {
|
||||
return false;
|
||||
}
|
||||
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
|
||||
return !(hashContainsA && hashContainsB);
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,15 +20,17 @@ package org.apache.doris.nereids.rules.exploration.join;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
@ -45,9 +47,20 @@ import java.util.Set;
|
||||
* which operands actually participate in the semi-join.
|
||||
*/
|
||||
public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFactory {
|
||||
public static final SemiJoinLogicalJoinTransposeProject LEFT_DEEP = new SemiJoinLogicalJoinTransposeProject(true);
|
||||
|
||||
public static final SemiJoinLogicalJoinTransposeProject ALL = new SemiJoinLogicalJoinTransposeProject(false);
|
||||
|
||||
private final boolean leftDeep;
|
||||
|
||||
public SemiJoinLogicalJoinTransposeProject(boolean leftDeep) {
|
||||
this.leftDeep = leftDeep;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return leftSemiLogicalJoin(logicalProject(logicalJoin()), group())
|
||||
.whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin())
|
||||
.when(this::conditionChecker)
|
||||
.then(topSemiJoin -> {
|
||||
LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left();
|
||||
@ -56,67 +69,77 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto
|
||||
GroupPlan b = bottomJoin.right();
|
||||
GroupPlan c = topSemiJoin.right();
|
||||
|
||||
boolean lasscom = a.getOutputSet().containsAll(project.getOutput());
|
||||
Set<Slot> aOutputSet = a.getOutputSet();
|
||||
|
||||
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
|
||||
|
||||
boolean lasscom = false;
|
||||
for (Expression hashJoinConjunct : hashJoinConjuncts) {
|
||||
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
|
||||
lasscom = ExpressionUtils.isIntersecting(usedSlot, aOutputSet) || lasscom;
|
||||
}
|
||||
|
||||
if (lasscom) {
|
||||
/*-
|
||||
* topSemiJoin newTopProject
|
||||
* / \ |
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
* project C newTopJoin
|
||||
* | -> / \
|
||||
* bottomJoin newBottomSemiJoin B
|
||||
* bottomJoin newBottomSemiJoin B
|
||||
* / \ / \
|
||||
* A B aNewProject C
|
||||
* |
|
||||
* A
|
||||
* A B A C
|
||||
*/
|
||||
List<NamedExpression> projects = project.getProjects();
|
||||
LogicalProject<GroupPlan> aNewProject = new LogicalProject<>(projects, a);
|
||||
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
|
||||
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
|
||||
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
|
||||
topSemiJoin.getOtherJoinCondition(), aNewProject, c);
|
||||
LogicalJoin<LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>, GroupPlan> newTopJoin
|
||||
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
|
||||
bottomJoin.getOtherJoinCondition(), newBottomSemiJoin, b);
|
||||
return new LogicalProject<>(projects, newTopJoin);
|
||||
topSemiJoin.getOtherJoinCondition(), a, c);
|
||||
|
||||
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
|
||||
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(),
|
||||
newBottomSemiJoin, b);
|
||||
|
||||
return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
|
||||
} else {
|
||||
/*-
|
||||
* topSemiJoin newTopProject
|
||||
* / \ |
|
||||
* project C newTopJoin
|
||||
* | / \
|
||||
* bottomJoin C --> A newBottomSemiJoin
|
||||
* / \ / \
|
||||
* A B bNewProject C
|
||||
* |
|
||||
* B
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
* project C newTopJoin
|
||||
* | / \
|
||||
* bottomJoin C --> A newBottomSemiJoin
|
||||
* / \ / \
|
||||
* A B B C
|
||||
*/
|
||||
List<NamedExpression> projects = project.getProjects();
|
||||
LogicalProject<GroupPlan> bNewProject = new LogicalProject<>(projects, b);
|
||||
LogicalJoin<LogicalProject<GroupPlan>, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
|
||||
LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>(
|
||||
topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(),
|
||||
topSemiJoin.getOtherJoinCondition(), bNewProject, c);
|
||||
topSemiJoin.getOtherJoinCondition(), b, c);
|
||||
|
||||
LogicalJoin<GroupPlan, LogicalJoin<LogicalProject<GroupPlan>, GroupPlan>> newTopJoin
|
||||
= new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(),
|
||||
bottomJoin.getOtherJoinCondition(), a, newBottomSemiJoin);
|
||||
return new LogicalProject<>(projects, newTopJoin);
|
||||
LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(),
|
||||
bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinCondition(),
|
||||
a, newBottomSemiJoin);
|
||||
|
||||
return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
|
||||
}
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
|
||||
}
|
||||
|
||||
// bottomJoin just return A OR B, else return false.
|
||||
// project of bottomJoin just return A OR B, else return false.
|
||||
private boolean conditionChecker(
|
||||
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) {
|
||||
Set<Slot> projectOutputSet = topJoin.left().getOutputSet();
|
||||
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemiJoin) {
|
||||
List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts();
|
||||
|
||||
Set<Slot> aOutputSet = topJoin.left().child().left().getOutputSet();
|
||||
Set<Slot> bOutputSet = topJoin.left().child().right().getOutputSet();
|
||||
List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
|
||||
List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();
|
||||
|
||||
boolean isProjectA = !ExpressionUtils.isIntersecting(projectOutputSet, aOutputSet);
|
||||
boolean isProjectB = !ExpressionUtils.isIntersecting(projectOutputSet, bOutputSet);
|
||||
|
||||
Preconditions.checkState(isProjectA || isProjectB, "project must contain child");
|
||||
return !(isProjectA && isProjectB);
|
||||
boolean hashContainsA = false;
|
||||
boolean hashContainsB = false;
|
||||
for (Expression hashJoinConjunct : hashJoinConjuncts) {
|
||||
Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance);
|
||||
hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA;
|
||||
hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB;
|
||||
}
|
||||
if (leftDeep && hashContainsB) {
|
||||
return false;
|
||||
}
|
||||
Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child");
|
||||
return !(hashContainsA && hashContainsB);
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,6 +37,7 @@ import java.util.Set;
|
||||
* LEFT-Semi/ANTI(X, LEFT-Semi/ANTI(Y, Z))
|
||||
*/
|
||||
public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
|
||||
public static final SemiJoinSemiJoinTranspose INSTANCE = new SemiJoinSemiJoinTranspose();
|
||||
|
||||
public static Set<Pair<JoinType, JoinType>> typeSet = ImmutableSet.of(
|
||||
Pair.of(JoinType.LEFT_SEMI_JOIN, JoinType.LEFT_SEMI_JOIN),
|
||||
@ -69,7 +70,7 @@ public class SemiJoinSemiJoinTranspose extends OneExplorationRuleFactory {
|
||||
newBottomJoin, b);
|
||||
|
||||
return newTopJoin;
|
||||
}).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM);
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE);
|
||||
}
|
||||
|
||||
private boolean typeChecker(LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> topJoin) {
|
||||
|
||||
@ -54,6 +54,6 @@ public class PushdownProjectThroughLimit extends OneRewriteRuleFactory {
|
||||
return new LogicalLimit<LogicalProject<GroupPlan>>(logicalLimit.getLimit(),
|
||||
logicalLimit.getOffset(), new LogicalProject<>(logicalProject.getProjects(),
|
||||
logicalLimit.child()));
|
||||
}).toRule(RuleType.SWAP_LIMIT_PROJECT);
|
||||
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGHT_LIMIT);
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,134 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SemiJoinLogicalJoinTransposeProjectTest {
|
||||
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeProjectLAsscom() {
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
* project C newTopJoin
|
||||
* | -> / \
|
||||
* bottomJoin newBottomSemiJoin B
|
||||
* / \ / \
|
||||
* A B A C
|
||||
*/
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.project(ImmutableList.of(0))
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
|
||||
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0);
|
||||
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.left();
|
||||
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
|
||||
|
||||
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
|
||||
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
|
||||
LogicalOlapScan newTopJoinRight = (LogicalOlapScan) newTopJoin.right();
|
||||
|
||||
Assertions.assertEquals("t1", newBottomJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
|
||||
Assertions.assertEquals("t2", newTopJoinRight.getTable().getName());
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeProjectLAsscomFail() {
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.project(ImmutableList.of(0, 2)) // t1.id, t2.id
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 0)) // t2.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(1, root.getLogicalExpressions().size());
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeProjectAll() {
|
||||
/*-
|
||||
* topSemiJoin project
|
||||
* / \ |
|
||||
* project C newTopJoin
|
||||
* | / \
|
||||
* bottomJoin C --> A newBottomSemiJoin
|
||||
* / \ / \
|
||||
* A B B C
|
||||
*/
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.project(ImmutableList.of(0, 2)) // t1.id, t2.id
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 0)) // t2.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(SemiJoinLogicalJoinTransposeProject.ALL.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
|
||||
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan.child(0);
|
||||
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.right();
|
||||
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
|
||||
|
||||
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
|
||||
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
|
||||
LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) newTopJoin.left();
|
||||
|
||||
Assertions.assertEquals("t1", newTopJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t2", newBottomJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,126 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.memo.Group;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SemiJoinLogicalJoinTransposeTest {
|
||||
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
private static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeLAsscom() {
|
||||
/*
|
||||
* topSemiJoin newTopJoin
|
||||
* / \ / \
|
||||
* bottomJoin C --> newBottomSemiJoin B
|
||||
* / \ / \
|
||||
* A B A C
|
||||
*/
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) // t1.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
|
||||
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan;
|
||||
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.left();
|
||||
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
|
||||
|
||||
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
|
||||
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
|
||||
LogicalOlapScan newTopJoinRight = (LogicalOlapScan) newTopJoin.right();
|
||||
|
||||
Assertions.assertEquals("t1", newBottomJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
|
||||
Assertions.assertEquals("t2", newTopJoinRight.getTable().getName());
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeLAsscomFail() {
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(2, 0)) // t2.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(SemiJoinLogicalJoinTranspose.LEFT_DEEP.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(1, root.getLogicalExpressions().size());
|
||||
});
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeAll() {
|
||||
/*
|
||||
* topSemiJoin newTopJoin
|
||||
* / \ / \
|
||||
* bottomJoin C --> A newBottomSemiJoin
|
||||
* / \ / \
|
||||
* A B B C
|
||||
*/
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(2, 0)) // t2.id = t3.id
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform(SemiJoinLogicalJoinTranspose.ALL.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
|
||||
LogicalJoin<?, ?> newTopJoin = (LogicalJoin<?, ?>) plan;
|
||||
LogicalJoin<?, ?> newBottomJoin = (LogicalJoin<?, ?>) newTopJoin.right();
|
||||
Assertions.assertEquals(JoinType.INNER_JOIN, newTopJoin.getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN, newBottomJoin.getJoinType());
|
||||
|
||||
LogicalOlapScan newTopJoinLeft = (LogicalOlapScan) newTopJoin.left();
|
||||
LogicalOlapScan newBottomJoinLeft = (LogicalOlapScan) newBottomJoin.left();
|
||||
LogicalOlapScan newBottomJoinRight = (LogicalOlapScan) newBottomJoin.right();
|
||||
|
||||
Assertions.assertEquals("t1", newTopJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t2", newBottomJoinLeft.getTable().getName());
|
||||
Assertions.assertEquals("t3", newBottomJoinRight.getTable().getName());
|
||||
});
|
||||
}
|
||||
}
|
||||
@ -29,11 +29,10 @@ import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SemiJoinTransposeTest {
|
||||
public class SemiJoinSemiJoinTransposeTest {
|
||||
public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
@ -41,21 +40,19 @@ public class SemiJoinTransposeTest {
|
||||
@Test
|
||||
public void testSemiJoinLogicalTransposeCommute() {
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.INNER_JOIN, Pair.of(0, 0))
|
||||
.project(ImmutableList.of(0))
|
||||
.hashJoinUsing(scan2, JoinType.LEFT_ANTI_JOIN, Pair.of(0, 0))
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0))
|
||||
.build();
|
||||
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.transform((new SemiJoinLogicalJoinTransposeProject()).build())
|
||||
.transform(SemiJoinSemiJoinTranspose.INSTANCE.build())
|
||||
.checkMemo(memo -> {
|
||||
Group root = memo.getRoot();
|
||||
Assertions.assertEquals(2, root.getLogicalExpressions().size());
|
||||
Plan plan = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
Plan join = memo.copyOut(root.getLogicalExpressions().get(1), false);
|
||||
|
||||
Plan join = plan.child(0);
|
||||
Assertions.assertTrue(join instanceof LogicalJoin);
|
||||
Assertions.assertEquals(JoinType.INNER_JOIN, ((LogicalJoin<?, ?>) join).getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_ANTI_JOIN, ((LogicalJoin<?, ?>) join).getJoinType());
|
||||
Assertions.assertEquals(JoinType.LEFT_SEMI_JOIN,
|
||||
((LogicalJoin<?, ?>) ((LogicalJoin<?, ?>) join).left()).getJoinType());
|
||||
});
|
||||
@ -58,10 +58,10 @@ public class LogicalPlanBuilder {
|
||||
return from(project);
|
||||
}
|
||||
|
||||
public LogicalPlanBuilder project(List<Integer> slots) {
|
||||
public LogicalPlanBuilder project(List<Integer> slotsIndex) {
|
||||
List<NamedExpression> projectExprs = Lists.newArrayList();
|
||||
for (int i = 0; i < slots.size(); i++) {
|
||||
projectExprs.add(this.plan.getOutput().get(i));
|
||||
for (Integer index : slotsIndex) {
|
||||
projectExprs.add(this.plan.getOutput().get(index));
|
||||
}
|
||||
LogicalProject<LogicalPlan> project = new LogicalProject<>(projectExprs, this.plan);
|
||||
return from(project);
|
||||
|
||||
Reference in New Issue
Block a user