[feature](nereids) add rule for semi/anti join exploration, when there is project between them (#13756)
This commit is contained in:
@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose;
|
||||
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject;
|
||||
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
|
||||
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
|
||||
import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
|
||||
@ -66,6 +67,7 @@ public class RuleSet {
|
||||
.add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
|
||||
.add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
|
||||
.add(SemiJoinSemiJoinTranspose.INSTANCE)
|
||||
.add(SemiJoinSemiJoinTransposeProject.INSTANCE)
|
||||
.add(new AggregateDisassemble())
|
||||
.add(new PushdownFilterThroughProject())
|
||||
.add(new MergeProjects())
|
||||
|
||||
@ -156,6 +156,7 @@ public enum RuleType {
|
||||
LOGICAL_ASSERT_NUM_ROWS_TO_PHYSICAL_ASSERT_NUM_ROWS(RuleTypeClass.IMPLEMENTATION),
|
||||
IMPLEMENTATION_SENTINEL(RuleTypeClass.IMPLEMENTATION),
|
||||
|
||||
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT(RuleTypeClass.EXPLORATION),
|
||||
// sentinel, use to count rules
|
||||
SENTINEL(RuleTypeClass.SENTINEL),
|
||||
;
|
||||
|
||||
@ -0,0 +1,99 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* rule for semi-semi transpose
|
||||
*/
|
||||
public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory {
|
||||
public static final SemiJoinSemiJoinTransposeProject INSTANCE = new SemiJoinSemiJoinTransposeProject();
|
||||
|
||||
/*
|
||||
* topSemi newTopSemi
|
||||
* / \ / \
|
||||
* abProject C acProject B
|
||||
* / ──► /
|
||||
* bottomSemi newBottomSemi
|
||||
* / \ / \
|
||||
* A B A C
|
||||
*/
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalJoin(logicalProject(logicalJoin()), group())
|
||||
.when(this::typeChecker)
|
||||
.when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child()))
|
||||
.then(topSemi -> {
|
||||
LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child();
|
||||
LogicalProject abProject = topSemi.left();
|
||||
GroupPlan a = bottomSemi.left();
|
||||
GroupPlan b = bottomSemi.right();
|
||||
GroupPlan c = topSemi.right();
|
||||
Set<Slot> aOutputSet = a.getOutputSet();
|
||||
Set<NamedExpression> acProjects = new HashSet<NamedExpression>(abProject.getProjects());
|
||||
|
||||
bottomSemi.getHashJoinConjuncts().stream().forEach(
|
||||
expression -> {
|
||||
expression.getInputSlots().stream().forEach(
|
||||
slot -> {
|
||||
if (aOutputSet.contains(slot)) {
|
||||
acProjects.add(slot);
|
||||
}
|
||||
});
|
||||
}
|
||||
);
|
||||
LogicalJoin newBottomSemi = new LogicalJoin(topSemi.getJoinType(), topSemi.getHashJoinConjuncts(),
|
||||
topSemi.getOtherJoinConjuncts(), a, c,
|
||||
bottomSemi.getJoinReorderContext());
|
||||
newBottomSemi.getJoinReorderContext().setHasCommute(false);
|
||||
newBottomSemi.getJoinReorderContext().setHasLAsscom(false);
|
||||
LogicalProject acProject = new LogicalProject(acProjects.stream().collect(Collectors.toList()),
|
||||
newBottomSemi);
|
||||
LogicalJoin newTopSemi = new LogicalJoin(bottomSemi.getJoinType(),
|
||||
bottomSemi.getHashJoinConjuncts(), bottomSemi.getOtherJoinConjuncts(),
|
||||
acProject, b,
|
||||
topSemi.getJoinReorderContext());
|
||||
newTopSemi.getJoinReorderContext().setHasLAsscom(true);
|
||||
//return newTopSemi;
|
||||
if (topSemi.getLogicalProperties().equals(newTopSemi)) {
|
||||
return newTopSemi;
|
||||
} else {
|
||||
return new LogicalProject<>(new ArrayList<>(topSemi.getOutput()), newTopSemi);
|
||||
}
|
||||
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT);
|
||||
}
|
||||
|
||||
public boolean typeChecker(LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) {
|
||||
return SemiJoinSemiJoinTranspose.VALID_TYPE_PAIR_SET
|
||||
.contains(Pair.of(topJoin.getJoinType(), topJoin.left().child().getJoinType()));
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,77 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.exploration.join;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.trees.plans.JoinType;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
|
||||
import org.apache.doris.nereids.util.LogicalPlanBuilder;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.nereids.util.PlanConstructor;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
public class SemiJoinSemiJoinTransposeProjectTest implements PatternMatchSupported {
|
||||
public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
|
||||
public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
|
||||
public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
|
||||
|
||||
@Test
|
||||
public void testSemiProjectSemiCommute() {
|
||||
/*
|
||||
* t1.name=t3.name t1.id=t2.id
|
||||
* topJoin newTopJoin
|
||||
* / \ / \
|
||||
* project t3 project t2
|
||||
* t1.name t1.name, t1.id
|
||||
* | |
|
||||
* t1.id=t2.id t1.name=t3.name
|
||||
* bottomJoin --> newBottomJoin
|
||||
* / \ / \
|
||||
* t1 t2 t1 t3
|
||||
*/
|
||||
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
|
||||
.hashJoinUsing(scan2, JoinType.LEFT_ANTI_JOIN, Pair.of(0, 0))
|
||||
.project(ImmutableList.of(1))
|
||||
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1))
|
||||
.build();
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
|
||||
.applyExploration(SemiJoinSemiJoinTransposeProject.INSTANCE.build())
|
||||
.printlnExploration()
|
||||
.matchesExploration(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalProject(
|
||||
logicalJoin(
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
|
||||
).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN)
|
||||
).when(project -> project.getProjects().size() == 2
|
||||
&& project.getProjects().get(0).getName().equals("id")
|
||||
&& project.getProjects().get(1).getName().equals("name")
|
||||
),
|
||||
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
|
||||
).when(join -> join.getJoinType() == JoinType.LEFT_ANTI_JOIN)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user