[feature](nereids) add rule for semi/anti join exploration, when there is project between them (#13756)

This commit is contained in:
minghong
2022-11-01 19:07:25 +08:00
committed by GitHub
parent f30b974d54
commit 1eef986e75
4 changed files with 179 additions and 0 deletions

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinLogicalJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.SemiJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg;
import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows;
import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation;
@ -66,6 +67,7 @@ public class RuleSet {
.add(SemiJoinLogicalJoinTranspose.LEFT_DEEP)
.add(SemiJoinLogicalJoinTransposeProject.LEFT_DEEP)
.add(SemiJoinSemiJoinTranspose.INSTANCE)
.add(SemiJoinSemiJoinTransposeProject.INSTANCE)
.add(new AggregateDisassemble())
.add(new PushdownFilterThroughProject())
.add(new MergeProjects())

View File

@ -156,6 +156,7 @@ public enum RuleType {
LOGICAL_ASSERT_NUM_ROWS_TO_PHYSICAL_ASSERT_NUM_ROWS(RuleTypeClass.IMPLEMENTATION),
IMPLEMENTATION_SENTINEL(RuleTypeClass.IMPLEMENTATION),
LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT(RuleTypeClass.EXPLORATION),
// sentinel, use to count rules
SENTINEL(RuleTypeClass.SENTINEL),
;

View File

@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
/**
* rule for semi-semi transpose
*/
public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory {
public static final SemiJoinSemiJoinTransposeProject INSTANCE = new SemiJoinSemiJoinTransposeProject();
/*
* topSemi newTopSemi
* / \ / \
* abProject C acProject B
* / ──► /
* bottomSemi newBottomSemi
* / \ / \
* A B A C
*/
@Override
public Rule build() {
return logicalJoin(logicalProject(logicalJoin()), group())
.when(this::typeChecker)
.when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child()))
.then(topSemi -> {
LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child();
LogicalProject abProject = topSemi.left();
GroupPlan a = bottomSemi.left();
GroupPlan b = bottomSemi.right();
GroupPlan c = topSemi.right();
Set<Slot> aOutputSet = a.getOutputSet();
Set<NamedExpression> acProjects = new HashSet<NamedExpression>(abProject.getProjects());
bottomSemi.getHashJoinConjuncts().stream().forEach(
expression -> {
expression.getInputSlots().stream().forEach(
slot -> {
if (aOutputSet.contains(slot)) {
acProjects.add(slot);
}
});
}
);
LogicalJoin newBottomSemi = new LogicalJoin(topSemi.getJoinType(), topSemi.getHashJoinConjuncts(),
topSemi.getOtherJoinConjuncts(), a, c,
bottomSemi.getJoinReorderContext());
newBottomSemi.getJoinReorderContext().setHasCommute(false);
newBottomSemi.getJoinReorderContext().setHasLAsscom(false);
LogicalProject acProject = new LogicalProject(acProjects.stream().collect(Collectors.toList()),
newBottomSemi);
LogicalJoin newTopSemi = new LogicalJoin(bottomSemi.getJoinType(),
bottomSemi.getHashJoinConjuncts(), bottomSemi.getOtherJoinConjuncts(),
acProject, b,
topSemi.getJoinReorderContext());
newTopSemi.getJoinReorderContext().setHasLAsscom(true);
//return newTopSemi;
if (topSemi.getLogicalProperties().equals(newTopSemi)) {
return newTopSemi;
} else {
return new LogicalProject<>(new ArrayList<>(topSemi.getOutput()), newTopSemi);
}
}).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT);
}
public boolean typeChecker(LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin) {
return SemiJoinSemiJoinTranspose.VALID_TYPE_PAIR_SET
.contains(Pair.of(topJoin.getJoinType(), topJoin.left().child().getJoinType()));
}
}

View File

@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.exploration.join;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
public class SemiJoinSemiJoinTransposeProjectTest implements PatternMatchSupported {
public static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
public static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
public static final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0);
@Test
public void testSemiProjectSemiCommute() {
/*
* t1.name=t3.name t1.id=t2.id
* topJoin newTopJoin
* / \ / \
* project t3 project t2
* t1.name t1.name, t1.id
* | |
* t1.id=t2.id t1.name=t3.name
* bottomJoin --> newBottomJoin
* / \ / \
* t1 t2 t1 t3
*/
LogicalPlan topJoin = new LogicalPlanBuilder(scan1)
.hashJoinUsing(scan2, JoinType.LEFT_ANTI_JOIN, Pair.of(0, 0))
.project(ImmutableList.of(1))
.hashJoinUsing(scan3, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 1))
.build();
PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin)
.applyExploration(SemiJoinSemiJoinTransposeProject.INSTANCE.build())
.printlnExploration()
.matchesExploration(
logicalProject(
logicalJoin(
logicalProject(
logicalJoin(
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t3"))
).when(join -> join.getJoinType() == JoinType.LEFT_SEMI_JOIN)
).when(project -> project.getProjects().size() == 2
&& project.getProjects().get(0).getName().equals("id")
&& project.getProjects().get(1).getName().equals("name")
),
logicalOlapScan().when(scan -> scan.getTable().getName().equals("t2"))
).when(join -> join.getJoinType() == JoinType.LEFT_ANTI_JOIN)
)
);
}
}